diff options
30 files changed, 486 insertions, 286 deletions
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..706ab462f --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,57 @@ +name: Build + +on: + workflow_run: + workflows: ["Lint & Test"] + branches: + - master + types: + - completed + +jobs: + build: + if: github.event.workflow_run.conclusion == 'success' + name: Build & Push + runs-on: ubuntu-latest + + steps: + # Create a commit SHA-based tag for the container repositories + - name: Create SHA Container Tag + id: sha_tag + run: | + tag=$(cut -c 1-7 <<< $GITHUB_SHA) + echo "::set-output name=tag::$tag" + + - name: Checkout code + uses: actions/checkout@v2 + + # The current version (v2) of Docker's build-push action uses + # buildx, which comes with BuildKit features that help us speed + # up our builds using additional cache features. Buildx also + # has a lot of other features that are not as relevant to us. + # + # See https://github.com/docker/build-push-action + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + + - name: Login to Github Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GHCR_TOKEN }} + + # Build and push the container to the GitHub Container + # Repository. The container will be tagged as "latest" + # and with the short SHA of the commit. + - name: Build and push + uses: docker/build-push-action@v2 + with: + context: . + file: ./Dockerfile + push: true + cache-from: type=registry,ref=ghcr.io/python-discord/bot:latest + cache-to: type=inline + tags: | + ghcr.io/python-discord/bot:latest + ghcr.io/python-discord/bot:${{ steps.sha_tag.outputs.tag }} diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml new file mode 100644 index 000000000..5444fc3de --- /dev/null +++ b/.github/workflows/lint-test.yml @@ -0,0 +1,115 @@ +name: Lint & Test + +on: + push: + branches: + - master + pull_request: + + +jobs: + lint-test: + runs-on: ubuntu-latest + env: + # Dummy values for required bot environment variables + BOT_API_KEY: foo + BOT_SENTRY_DSN: blah + BOT_TOKEN: bar + REDDIT_CLIENT_ID: spam + REDDIT_SECRET: ham + REDIS_PASSWORD: '' + + # Configure pip to cache dependencies and do a user install + PIP_NO_CACHE_DIR: false + PIP_USER: 1 + + # Hide the graphical elements from pipenv's output + PIPENV_HIDE_EMOJIS: 1 + PIPENV_NOSPIN: 1 + + # Make sure pipenv does not try reuse an environment it's running in + PIPENV_IGNORE_VIRTUALENVS: 1 + + # Specify explicit paths for python dependencies and the pre-commit + # environment so we know which directories to cache + PYTHONUSERBASE: ${{ github.workspace }}/.cache/py-user-base + PRE_COMMIT_HOME: ${{ github.workspace }}/.cache/pre-commit-cache + + steps: + - name: Add custom PYTHONUSERBASE to PATH + run: echo '${{ env.PYTHONUSERBASE }}/bin/' >> $GITHUB_PATH + + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Setup python + id: python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + # This step caches our Python dependencies. To make sure we + # only restore a cache when the dependencies, the python version, + # the runner operating system, and the dependency location haven't + # changed, we create a cache key that is a composite of those states. + # + # Only when the context is exactly the same, we will restore the cache. + - name: Python Dependency Caching + uses: actions/cache@v2 + id: python_cache + with: + path: ${{ env.PYTHONUSERBASE }} + key: "python-0-${{ runner.os }}-${{ env.PYTHONUSERBASE }}-\ + ${{ steps.python.outputs.python-version }}-\ + ${{ hashFiles('./Pipfile', './Pipfile.lock') }}" + + # Install our dependencies if we did not restore a dependency cache + - name: Install dependencies using pipenv + if: steps.python_cache.outputs.cache-hit != 'true' + run: | + pip install pipenv + pipenv install --dev --deploy --system + + # This step caches our pre-commit environment. To make sure we + # do create a new environment when our pre-commit setup changes, + # we create a cache key based on relevant factors. + - name: Pre-commit Environment Caching + uses: actions/cache@v2 + with: + path: ${{ env.PRE_COMMIT_HOME }} + key: "precommit-0-${{ runner.os }}-${{ env.PRE_COMMIT_HOME }}-\ + ${{ steps.python.outputs.python-version }}-\ + ${{ hashFiles('./.pre-commit-config.yaml') }}" + + # We will not run `flake8` here, as we will use a separate flake8 + # action. As pre-commit does not support user installs, we set + # PIP_USER=0 to not do a user install. + - name: Run pre-commit hooks + run: export PIP_USER=0; SKIP=flake8 pre-commit run --all-files + + # Run flake8 and have it format the linting errors in the format of + # the GitHub Workflow command to register error annotations. This + # means that our flake8 output is automatically added as an error + # annotation to both the run result and in the "Files" tab of a + # pull request. + # + # Format used: + # ::error file={filename},line={line},col={col}::{message} + - name: Run flake8 + run: "flake8 \ + --format='::error file=%(path)s,line=%(row)d,col=%(col)d::\ + [flake8] %(code)s: %(text)s'" + + # We run `coverage` using the `python` command so we can suppress + # irrelevant warnings in our CI output. + - name: Run tests and generate coverage report + run: | + python -Wignore -m coverage run -m unittest + coverage report -m + + # This step will publish the coverage reports coveralls.io and + # print a "job" link in the output of the GitHub Action + - name: Publish coverage report to coveralls.io + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: coveralls diff --git a/.gitignore b/.gitignore index 2074887ad..9186dbe06 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,7 @@ ENV/ # Logfiles log.* *.log.* +!log.py # Custom user configuration config.yml @@ -48,8 +48,8 @@ python_version = "3.8" start = "python -m bot" lint = "pre-commit run --all-files" precommit = "pre-commit install" -build = "docker build -t pythondiscord/bot:latest -f Dockerfile ." -push = "docker push pythondiscord/bot:latest" +build = "docker build -t ghcr.io/python-discord/bot:latest -f Dockerfile ." +push = "docker push ghcr.io/python-discord/bot:latest" test = "coverage run -m unittest" html = "coverage html" report = "coverage report" @@ -1,7 +1,8 @@ # Python Utility Bot [](https://discord.gg/2B963hn) - +[![Lint & Test][1]][2] +[![Build][3]][4] [](https://coveralls.io/github/python-discord/bot) [](LICENSE) [](https://pythondiscord.com) @@ -10,3 +11,8 @@ This project is a Discord bot specifically for use with the Python Discord serve and other tools to help keep the server running like a well-oiled machine. Read the [Contributing Guide](https://pythondiscord.com/pages/contributing/bot/) on our website if you're interested in helping out. + +[1]: https://github.com/python-discord/bot/workflows/Lint%20&%20Test/badge.svg?branch=master +[2]: https://github.com/python-discord/bot/actions?query=workflow%3A%22Lint+%26+Test%22+branch%3Amaster +[3]: https://github.com/python-discord/bot/workflows/Build/badge.svg?branch=master +[4]: https://github.com/python-discord/bot/actions?query=workflow%3ABuild+branch%3Amaster diff --git a/bot/__init__.py b/bot/__init__.py index 4fce04532..8f880b8e6 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,78 +1,25 @@ import asyncio -import logging import os -import sys from functools import partial, partialmethod -from logging import Logger, handlers -from pathlib import Path +from typing import TYPE_CHECKING -import coloredlogs from discord.ext import commands +from bot import log from bot.command import Command -TRACE_LEVEL = logging.TRACE = 5 -logging.addLevelName(TRACE_LEVEL, "TRACE") - - -def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: - """ - Log 'msg % args' with severity 'TRACE'. - - To pass exception information, use the keyword argument exc_info with - a true value, e.g. - - logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) - """ - if self.isEnabledFor(TRACE_LEVEL): - self._log(TRACE_LEVEL, msg, args, **kwargs) - - -Logger.trace = monkeypatch_trace - -DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") - -log_level = TRACE_LEVEL if DEBUG_MODE else logging.INFO -format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" -log_format = logging.Formatter(format_string) - -log_file = Path("logs", "bot.log") -log_file.parent.mkdir(exist_ok=True) -file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8") -file_handler.setFormatter(log_format) - -root_log = logging.getLogger() -root_log.setLevel(log_level) -root_log.addHandler(file_handler) - -if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: - coloredlogs.DEFAULT_LEVEL_STYLES = { - **coloredlogs.DEFAULT_LEVEL_STYLES, - "trace": {"color": 246}, - "critical": {"background": "red"}, - "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"] - } - -if "COLOREDLOGS_LOG_FORMAT" not in os.environ: - coloredlogs.DEFAULT_LOG_FORMAT = format_string - -if "COLOREDLOGS_LOG_LEVEL" not in os.environ: - coloredlogs.DEFAULT_LOG_LEVEL = log_level - -coloredlogs.install(logger=root_log, stream=sys.stdout) - -logging.getLogger("discord").setLevel(logging.WARNING) -logging.getLogger("websockets").setLevel(logging.WARNING) -logging.getLogger("chardet").setLevel(logging.WARNING) -logging.getLogger("async_rediscache").setLevel(logging.WARNING) +if TYPE_CHECKING: + from bot.bot import Bot +log.setup() # On Windows, the selector event loop is required for aiodns. if os.name == "nt": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - # Monkey-patch discord.py decorators to use the Command subclass which supports root aliases. # Must be patched before any cogs are added. commands.command = partial(commands.command, cls=Command) commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=Command) + +instance: "Bot" = None # Global Bot instance. diff --git a/bot/__main__.py b/bot/__main__.py index 367be1300..257216fa7 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,76 +1,10 @@ -import asyncio -import logging - -import discord -import sentry_sdk -from async_rediscache import RedisSession -from discord.ext.commands import when_mentioned_or -from sentry_sdk.integrations.aiohttp import AioHttpIntegration -from sentry_sdk.integrations.logging import LoggingIntegration -from sentry_sdk.integrations.redis import RedisIntegration - +import bot from bot import constants from bot.bot import Bot -from bot.utils.extensions import EXTENSIONS - -# Set up Sentry. -sentry_logging = LoggingIntegration( - level=logging.DEBUG, - event_level=logging.WARNING -) - -sentry_sdk.init( - dsn=constants.Bot.sentry_dsn, - integrations=[ - sentry_logging, - AioHttpIntegration(), - RedisIntegration(), - ] -) - -# Create the redis session instance. -redis_session = RedisSession( - address=(constants.Redis.host, constants.Redis.port), - password=constants.Redis.password, - minsize=1, - maxsize=20, - use_fakeredis=constants.Redis.use_fakeredis, - global_namespace="bot", -) - -# Connect redis session to ensure it's connected before we try to access Redis -# from somewhere within the bot. We create the event loop in the same way -# discord.py normally does and pass it to the bot's __init__. -loop = asyncio.get_event_loop() -loop.run_until_complete(redis_session.connect()) - - -# Instantiate the bot. -allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] -intents = discord.Intents().all() -intents.presences = False -intents.dm_typing = False -intents.dm_reactions = False -intents.invites = False -intents.webhooks = False -intents.integrations = False -bot = Bot( - redis_session=redis_session, - loop=loop, - command_prefix=when_mentioned_or(constants.Bot.prefix), - activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), - case_insensitive=True, - max_messages=10_000, - allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), - intents=intents, -) - -# Load extensions. -extensions = set(EXTENSIONS) # Create a mutable copy. -if not constants.HelpChannels.enable: - extensions.remove("bot.exts.help_channels") +from bot.log import setup_sentry -for extension in extensions: - bot.load_extension(extension) +setup_sentry() -bot.run(constants.Bot.token) +bot.instance = Bot.create() +bot.instance.load_extensions() +bot.instance.run(constants.Bot.token) diff --git a/bot/bot.py b/bot/bot.py index b2e5237fe..36cf7d30a 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -11,7 +11,7 @@ from async_rediscache import RedisSession from discord.ext import commands from sentry_sdk import push_scope -from bot import DEBUG_MODE, api, constants +from bot import api, constants from bot.async_stats import AsyncStatsClient log = logging.getLogger('bot') @@ -40,7 +40,7 @@ class Bot(commands.Bot): statsd_url = constants.Stats.statsd_host - if DEBUG_MODE: + if constants.DEBUG_MODE: # Since statsd is UDP, there are no errors for sending to a down port. # For this reason, setting the statsd host to 127.0.0.1 for development # will effectively disable stats. @@ -95,6 +95,43 @@ class Bot(commands.Bot): # Build the FilterList cache self.loop.create_task(self.cache_filter_list_data()) + @classmethod + def create(cls) -> "Bot": + """Create and return an instance of a Bot.""" + loop = asyncio.get_event_loop() + allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] + + intents = discord.Intents().all() + intents.presences = False + intents.dm_typing = False + intents.dm_reactions = False + intents.invites = False + intents.webhooks = False + intents.integrations = False + + return cls( + redis_session=_create_redis_session(loop), + loop=loop, + command_prefix=commands.when_mentioned_or(constants.Bot.prefix), + activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"), + case_insensitive=True, + max_messages=10_000, + allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles), + intents=intents, + ) + + def load_extensions(self) -> None: + """Load all enabled extensions.""" + # Must be done here to avoid a circular import. + from bot.utils.extensions import EXTENSIONS + + extensions = set(EXTENSIONS) # Create a mutable copy. + if not constants.HelpChannels.enable: + extensions.remove("bot.exts.help_channels") + + for extension in extensions: + self.load_extension(extension) + def add_cog(self, cog: commands.Cog) -> None: """Adds a "cog" to the bot and logs the operation.""" super().add_cog(cog) @@ -243,3 +280,22 @@ class Bot(commands.Bot): for alias in getattr(command, "root_aliases", ()): self.all_commands.pop(alias, None) + + +def _create_redis_session(loop: asyncio.AbstractEventLoop) -> RedisSession: + """ + Create and connect to a redis session. + + Ensure the connection is established before returning to prevent race conditions. + `loop` is the event loop on which to connect. The Bot should use this same event loop. + """ + redis_session = RedisSession( + address=(constants.Redis.host, constants.Redis.port), + password=constants.Redis.password, + minsize=1, + maxsize=20, + use_fakeredis=constants.Redis.use_fakeredis, + global_namespace="bot", + ) + loop.run_until_complete(redis_session.connect()) + return redis_session diff --git a/bot/constants.py b/bot/constants.py index 731f06fed..2126b2b37 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -632,7 +632,7 @@ class Event(Enum): # Debug mode -DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False +DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") # Paths BOT_DIR = os.path.dirname(__file__) diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py index 6e85e2b7d..48d2b6f02 100644 --- a/bot/exts/backend/sync/_cog.py +++ b/bot/exts/backend/sync/_cog.py @@ -18,9 +18,6 @@ class Sync(Cog): def __init__(self, bot: Bot) -> None: self.bot = bot - self.role_syncer = _syncers.RoleSyncer(self.bot) - self.user_syncer = _syncers.UserSyncer(self.bot) - self.bot.loop.create_task(self.sync_guild()) async def sync_guild(self) -> None: @@ -31,7 +28,7 @@ class Sync(Cog): if guild is None: return - for syncer in (self.role_syncer, self.user_syncer): + for syncer in (_syncers.RoleSyncer, _syncers.UserSyncer): await syncer.sync(guild) async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None: @@ -171,10 +168,10 @@ class Sync(Cog): @commands.has_permissions(administrator=True) async def sync_roles_command(self, ctx: Context) -> None: """Manually synchronise the guild's roles with the roles on the site.""" - await self.role_syncer.sync(ctx.guild, ctx) + await _syncers.RoleSyncer.sync(ctx.guild, ctx) @sync_group.command(name='users') @commands.has_permissions(administrator=True) async def sync_users_command(self, ctx: Context) -> None: """Manually synchronise the guild's users with the users on the site.""" - await self.user_syncer.sync(ctx.guild, ctx) + await _syncers.UserSyncer.sync(ctx.guild, ctx) diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py index 38468c2b1..2eb9f9971 100644 --- a/bot/exts/backend/sync/_syncers.py +++ b/bot/exts/backend/sync/_syncers.py @@ -6,8 +6,8 @@ from collections import namedtuple from discord import Guild from discord.ext.commands import Context +import bot from bot.api import ResponseCodeError -from bot.bot import Bot log = logging.getLogger(__name__) @@ -17,57 +17,60 @@ _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) _Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) +# Implementation of static abstract methods are not enforced if the subclass is never instantiated. +# However, methods are kept abstract to at least symbolise that they should be abstract. class Syncer(abc.ABC): """Base class for synchronising the database with objects in the Discord cache.""" - def __init__(self, bot: Bot) -> None: - self.bot = bot - + @staticmethod @property @abc.abstractmethod - def name(self) -> str: + def name() -> str: """The name of the syncer; used in output messages and logging.""" raise NotImplementedError # pragma: no cover + @staticmethod @abc.abstractmethod - async def _get_diff(self, guild: Guild) -> _Diff: + async def _get_diff(guild: Guild) -> _Diff: """Return the difference between the cache of `guild` and the database.""" raise NotImplementedError # pragma: no cover + @staticmethod @abc.abstractmethod - async def _sync(self, diff: _Diff) -> None: + async def _sync(diff: _Diff) -> None: """Perform the API calls for synchronisation.""" raise NotImplementedError # pragma: no cover - async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + @classmethod + async def sync(cls, guild: Guild, ctx: t.Optional[Context] = None) -> None: """ Synchronise the database with the cache of `guild`. If `ctx` is given, send a message with the results. """ - log.info(f"Starting {self.name} syncer.") + log.info(f"Starting {cls.name} syncer.") if ctx: - message = await ctx.send(f"📊 Synchronising {self.name}s.") + message = await ctx.send(f"📊 Synchronising {cls.name}s.") else: message = None - diff = await self._get_diff(guild) + diff = await cls._get_diff(guild) try: - await self._sync(diff) + await cls._sync(diff) except ResponseCodeError as e: - log.exception(f"{self.name} syncer failed!") + log.exception(f"{cls.name} syncer failed!") # Don't show response text because it's probably some really long HTML. results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" - content = f":x: Synchronisation of {self.name}s failed: {results}" + content = f":x: Synchronisation of {cls.name}s failed: {results}" else: diff_dict = diff._asdict() results = (f"{name} `{len(val)}`" for name, val in diff_dict.items() if val is not None) results = ", ".join(results) - log.info(f"{self.name} syncer finished: {results}.") - content = f":ok_hand: Synchronisation of {self.name}s complete: {results}" + log.info(f"{cls.name} syncer finished: {results}.") + content = f":ok_hand: Synchronisation of {cls.name}s complete: {results}" if message: await message.edit(content=content) @@ -78,10 +81,11 @@ class RoleSyncer(Syncer): name = "role" - async def _get_diff(self, guild: Guild) -> _Diff: + @staticmethod + async def _get_diff(guild: Guild) -> _Diff: """Return the difference of roles between the cache of `guild` and the database.""" log.trace("Getting the diff for roles.") - roles = await self.bot.api_client.get('bot/roles') + roles = await bot.instance.api_client.get('bot/roles') # Pack DB roles and guild roles into one common, hashable format. # They're hashable so that they're easily comparable with sets later. @@ -110,19 +114,20 @@ class RoleSyncer(Syncer): return _Diff(roles_to_create, roles_to_update, roles_to_delete) - async def _sync(self, diff: _Diff) -> None: + @staticmethod + async def _sync(diff: _Diff) -> None: """Synchronise the database with the role cache of `guild`.""" log.trace("Syncing created roles...") for role in diff.created: - await self.bot.api_client.post('bot/roles', json=role._asdict()) + await bot.instance.api_client.post('bot/roles', json=role._asdict()) log.trace("Syncing updated roles...") for role in diff.updated: - await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + await bot.instance.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) log.trace("Syncing deleted roles...") for role in diff.deleted: - await self.bot.api_client.delete(f'bot/roles/{role.id}') + await bot.instance.api_client.delete(f'bot/roles/{role.id}') class UserSyncer(Syncer): @@ -130,7 +135,8 @@ class UserSyncer(Syncer): name = "user" - async def _get_diff(self, guild: Guild) -> _Diff: + @staticmethod + async def _get_diff(guild: Guild) -> _Diff: """Return the difference of users between the cache of `guild` and the database.""" log.trace("Getting the diff for users.") @@ -138,7 +144,7 @@ class UserSyncer(Syncer): users_to_update = [] seen_guild_users = set() - async for db_user in self._get_users(): + async for db_user in UserSyncer._get_users(): # Store user fields which are to be updated. updated_fields = {} @@ -185,24 +191,26 @@ class UserSyncer(Syncer): return _Diff(users_to_create, users_to_update, None) - async def _get_users(self) -> t.AsyncIterable: + @staticmethod + async def _get_users() -> t.AsyncIterable: """GET users from database.""" query_params = { "page": 1 } while query_params["page"]: - res = await self.bot.api_client.get("bot/users", params=query_params) + res = await bot.instance.api_client.get("bot/users", params=query_params) for user in res["results"]: yield user query_params["page"] = res["next_page_no"] - async def _sync(self, diff: _Diff) -> None: + @staticmethod + async def _sync(diff: _Diff) -> None: """Synchronise the database with the user cache of `guild`.""" log.trace("Syncing created users...") if diff.created: - await self.bot.api_client.post("bot/users", json=diff.created) + await bot.instance.api_client.post("bot/users", json=diff.created) log.trace("Syncing updated users...") if diff.updated: - await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated) + await bot.instance.api_client.patch("bot/users/bulk_patch", json=diff.updated) diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py index 062d4fcfe..f5a8b251b 100644 --- a/bot/exts/help_channels.py +++ b/bot/exts/help_channels.py @@ -380,16 +380,13 @@ class HelpChannels(commands.Cog): try: self.available_category = await channel_utils.try_get_channel( - constants.Categories.help_available, - self.bot + constants.Categories.help_available ) self.in_use_category = await channel_utils.try_get_channel( - constants.Categories.help_in_use, - self.bot + constants.Categories.help_in_use ) self.dormant_category = await channel_utils.try_get_channel( - constants.Categories.help_dormant, - self.bot + constants.Categories.help_dormant ) except discord.HTTPException: log.exception("Failed to get a category; cog will be removed") @@ -500,7 +497,7 @@ class HelpChannels(commands.Cog): options should be avoided, as it may interfere with the category move we perform. """ # Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had. - category = await channel_utils.try_get_channel(category_id, self.bot) + category = await channel_utils.try_get_channel(category_id) payload = [{"id": c.id, "position": c.position} for c in category.channels] diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py index 1e0feab0d..9094d9d15 100644 --- a/bot/exts/info/codeblock/_cog.py +++ b/bot/exts/info/codeblock/_cog.py @@ -114,7 +114,7 @@ class CodeBlockCog(Cog, name="Code Block"): bot_message = await message.channel.send(f"Hey {message.author.mention}!", embed=embed) self.codeblock_message_ids[message.id] = bot_message.id - self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,), self.bot)) + self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,))) # Increase amount of codeblock correction in stats self.bot.stats.incr("codeblock_corrections") diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py index 7ec8caa4b..9b5bd6504 100644 --- a/bot/exts/info/doc.py +++ b/bot/exts/info/doc.py @@ -365,7 +365,7 @@ class Doc(commands.Cog): await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) else: msg = await ctx.send(embed=doc_embed) - await wait_for_deletion(msg, (ctx.author.id,), client=self.bot) + await wait_for_deletion(msg, (ctx.author.id,)) @docs_group.command(name='set', aliases=('s',)) @commands.has_any_role(*MODERATION_ROLES) diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py index 599c5d5c0..461ff82fd 100644 --- a/bot/exts/info/help.py +++ b/bot/exts/info/help.py @@ -186,7 +186,7 @@ class CustomHelpCommand(HelpCommand): """Send help for a single command.""" embed = await self.command_formatting(command) message = await self.context.send(embed=embed) - await wait_for_deletion(message, (self.context.author.id,), self.context.bot) + await wait_for_deletion(message, (self.context.author.id,)) @staticmethod def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]: @@ -225,7 +225,7 @@ class CustomHelpCommand(HelpCommand): embed.description += f"\n**Subcommands:**\n{command_details}" message = await self.context.send(embed=embed) - await wait_for_deletion(message, (self.context.author.id,), self.context.bot) + await wait_for_deletion(message, (self.context.author.id,)) async def send_cog_help(self, cog: Cog) -> None: """Send help for a cog.""" @@ -241,7 +241,7 @@ class CustomHelpCommand(HelpCommand): embed.description += f"\n\n**Commands:**\n{command_details}" message = await self.context.send(embed=embed) - await wait_for_deletion(message, (self.context.author.id,), self.context.bot) + await wait_for_deletion(message, (self.context.author.id,)) @staticmethod def _category_key(command: Command) -> str: diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index ae95ac1ef..8f15f932b 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -236,7 +236,6 @@ class Tags(Cog): await wait_for_deletion( await ctx.send(embed=Embed.from_dict(tag['embed'])), [ctx.author.id], - self.bot ) elif founds and len(tag_name) >= 3: await wait_for_deletion( @@ -247,7 +246,6 @@ class Tags(Cog): ) ), [ctx.author.id], - self.bot ) else: diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py index 1b4900f42..3521c8fd4 100644 --- a/bot/exts/utils/internal.py +++ b/bot/exts/utils/internal.py @@ -30,7 +30,7 @@ class Internal(Cog): self.ln = 0 self.stdout = StringIO() - self.interpreter = Interpreter(bot) + self.interpreter = Interpreter() self.socket_since = datetime.utcnow() self.socket_event_total = 0 @@ -195,7 +195,7 @@ async def func(): # (None,) -> Any truncate_index = newline_truncate_index if len(out) > truncate_index: - paste_link = await send_to_paste_service(self.bot.http_session, out, extension="py") + paste_link = await send_to_paste_service(out, extension="py") if paste_link is not None: paste_text = f"full contents at {paste_link}" else: diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 41cb00541..9f480c067 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -70,7 +70,7 @@ class Snekbox(Cog): if len(output) > MAX_PASTE_LEN: log.info("Full output is too long to upload") return "too long to upload" - return await send_to_paste_service(self.bot.http_session, output, extension="txt") + return await send_to_paste_service(output, extension="txt") @staticmethod def prepare_input(code: str) -> str: @@ -219,7 +219,7 @@ class Snekbox(Cog): response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") else: response = await ctx.send(msg) - self.bot.loop.create_task(wait_for_deletion(response, (ctx.author.id,), ctx.bot)) + self.bot.loop.create_task(wait_for_deletion(response, (ctx.author.id,))) log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") return response diff --git a/bot/interpreter.py b/bot/interpreter.py index 8b7268746..b58f7a6b0 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -4,7 +4,7 @@ from typing import Any from discord.ext.commands import Context -from bot.bot import Bot +import bot CODE_TEMPLATE = """ async def _func(): @@ -21,8 +21,8 @@ class Interpreter(InteractiveInterpreter): write_callable = None - def __init__(self, bot: Bot): - locals_ = {"bot": bot} + def __init__(self): + locals_ = {"bot": bot.instance} super().__init__(locals_) async def run(self, code: str, ctx: Context, io: StringIO, *args, **kwargs) -> Any: diff --git a/bot/log.py b/bot/log.py new file mode 100644 index 000000000..13141de40 --- /dev/null +++ b/bot/log.py @@ -0,0 +1,86 @@ +import logging +import os +import sys +from logging import Logger, handlers +from pathlib import Path + +import coloredlogs +import sentry_sdk +from sentry_sdk.integrations.aiohttp import AioHttpIntegration +from sentry_sdk.integrations.logging import LoggingIntegration +from sentry_sdk.integrations.redis import RedisIntegration + +from bot import constants + +TRACE_LEVEL = 5 + + +def setup() -> None: + """Set up loggers.""" + logging.TRACE = TRACE_LEVEL + logging.addLevelName(TRACE_LEVEL, "TRACE") + Logger.trace = _monkeypatch_trace + + log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO + format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" + log_format = logging.Formatter(format_string) + + log_file = Path("logs", "bot.log") + log_file.parent.mkdir(exist_ok=True) + file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8") + file_handler.setFormatter(log_format) + + root_log = logging.getLogger() + root_log.setLevel(log_level) + root_log.addHandler(file_handler) + + if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: + coloredlogs.DEFAULT_LEVEL_STYLES = { + **coloredlogs.DEFAULT_LEVEL_STYLES, + "trace": {"color": 246}, + "critical": {"background": "red"}, + "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"] + } + + if "COLOREDLOGS_LOG_FORMAT" not in os.environ: + coloredlogs.DEFAULT_LOG_FORMAT = format_string + + if "COLOREDLOGS_LOG_LEVEL" not in os.environ: + coloredlogs.DEFAULT_LOG_LEVEL = log_level + + coloredlogs.install(logger=root_log, stream=sys.stdout) + + logging.getLogger("discord").setLevel(logging.WARNING) + logging.getLogger("websockets").setLevel(logging.WARNING) + logging.getLogger("chardet").setLevel(logging.WARNING) + logging.getLogger("async_rediscache").setLevel(logging.WARNING) + + +def setup_sentry() -> None: + """Set up the Sentry logging integrations.""" + sentry_logging = LoggingIntegration( + level=logging.DEBUG, + event_level=logging.WARNING + ) + + sentry_sdk.init( + dsn=constants.Bot.sentry_dsn, + integrations=[ + sentry_logging, + AioHttpIntegration(), + RedisIntegration(), + ] + ) + + +def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: + """ + Log 'msg % args' with severity 'TRACE'. + + To pass exception information, use the keyword argument exc_info with + a true value, e.g. + + logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) + """ + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, msg, args, **kwargs) diff --git a/bot/utils/channel.py b/bot/utils/channel.py index 6bf70bfde..0c072184c 100644 --- a/bot/utils/channel.py +++ b/bot/utils/channel.py @@ -2,6 +2,7 @@ import logging import discord +import bot from bot import constants from bot.constants import Categories @@ -36,14 +37,14 @@ def is_in_category(channel: discord.TextChannel, category_id: int) -> bool: return getattr(channel, "category_id", None) == category_id -async def try_get_channel(channel_id: int, client: discord.Client) -> discord.abc.GuildChannel: +async def try_get_channel(channel_id: int) -> discord.abc.GuildChannel: """Attempt to get or fetch a channel and return it.""" log.trace(f"Getting the channel {channel_id}.") - channel = client.get_channel(channel_id) + channel = bot.instance.get_channel(channel_id) if not channel: log.debug(f"Channel {channel_id} is not in cache; fetching from API.") - channel = await client.fetch_channel(channel_id) + channel = await bot.instance.fetch_channel(channel_id) log.trace(f"Channel #{channel} ({channel_id}) retrieved.") return channel diff --git a/bot/utils/messages.py b/bot/utils/messages.py index b6c7cab50..42bde358d 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -10,6 +10,7 @@ import discord from discord.errors import HTTPException from discord.ext.commands import Context +import bot from bot.constants import Emojis, NEGATIVE_REPLIES log = logging.getLogger(__name__) @@ -18,7 +19,6 @@ log = logging.getLogger(__name__) async def wait_for_deletion( message: discord.Message, user_ids: Sequence[discord.abc.Snowflake], - client: discord.Client, deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, attach_emojis: bool = True, @@ -49,7 +49,7 @@ async def wait_for_deletion( ) with contextlib.suppress(asyncio.TimeoutError): - await client.wait_for('reaction_add', check=check, timeout=timeout) + await bot.instance.wait_for('reaction_add', check=check, timeout=timeout) await message.delete() diff --git a/bot/utils/services.py b/bot/utils/services.py index 087b9f969..5949c9e48 100644 --- a/bot/utils/services.py +++ b/bot/utils/services.py @@ -1,8 +1,9 @@ import logging from typing import Optional -from aiohttp import ClientConnectorError, ClientSession +from aiohttp import ClientConnectorError +import bot from bot.constants import URLs log = logging.getLogger(__name__) @@ -10,11 +11,10 @@ log = logging.getLogger(__name__) FAILED_REQUEST_ATTEMPTS = 3 -async def send_to_paste_service(http_session: ClientSession, contents: str, *, extension: str = "") -> Optional[str]: +async def send_to_paste_service(contents: str, *, extension: str = "") -> Optional[str]: """ Upload `contents` to the paste service. - `http_session` should be the current running ClientSession from aiohttp `extension` is added to the output URL When an error occurs, `None` is returned, otherwise the generated URL with the suffix. @@ -24,7 +24,7 @@ async def send_to_paste_service(http_session: ClientSession, contents: str, *, e paste_url = URLs.paste_service.format(key="documents") for attempt in range(1, FAILED_REQUEST_ATTEMPTS + 1): try: - async with http_session.post(paste_url, data=contents) as response: + async with bot.instance.http_session.post(paste_url, data=contents) as response: response_json = await response.json() except ClientConnectorError: log.warning( diff --git a/docker-compose.yml b/docker-compose.yml index dc89e8885..0002d1d56 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: - "127.0.0.1:6379:6379" snekbox: - image: pythondiscord/snekbox:latest + image: ghcr.io/python-discord/snekbox:latest init: true ipc: none ports: diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py index 4953550f9..3ad9db9c3 100644 --- a/tests/bot/exts/backend/sync/test_base.py +++ b/tests/bot/exts/backend/sync/test_base.py @@ -15,28 +15,21 @@ class TestSyncer(Syncer): _sync = mock.AsyncMock() -class SyncerBaseTests(unittest.TestCase): - """Tests for the syncer base class.""" - - def setUp(self): - self.bot = helpers.MockBot() - - def test_instantiation_fails_without_abstract_methods(self): - """The class must have abstract methods implemented.""" - with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): - Syncer(self.bot) - - class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for main function orchestrating the sync.""" def setUp(self): - self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) - self.syncer = TestSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot(user=helpers.MockMember(bot=True))) + self.bot = patcher.start() + self.addCleanup(patcher.stop) + self.guild = helpers.MockGuild() + TestSyncer._get_diff.reset_mock(return_value=True, side_effect=True) + TestSyncer._sync.reset_mock(return_value=True, side_effect=True) + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock - self.syncer._get_diff.return_value = mock.MagicMock() + TestSyncer._get_diff.return_value = mock.MagicMock() async def test_sync_message_edited(self): """The message should be edited if one was sent, even if the sync has an API error.""" @@ -48,11 +41,11 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): for message, side_effect, should_edit in subtests: with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): - self.syncer._sync.side_effect = side_effect + TestSyncer._sync.side_effect = side_effect ctx = helpers.MockContext() ctx.send.return_value = message - await self.syncer.sync(self.guild, ctx) + await TestSyncer.sync(self.guild, ctx) if should_edit: message.edit.assert_called_once() @@ -67,7 +60,7 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): for ctx, message in subtests: with self.subTest(ctx=ctx, message=message): - await self.syncer.sync(self.guild, ctx) + await TestSyncer.sync(self.guild, ctx) if ctx is not None: ctx.send.assert_called_once() diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py index 063a82754..22a07313e 100644 --- a/tests/bot/exts/backend/sync/test_cog.py +++ b/tests/bot/exts/backend/sync/test_cog.py @@ -29,24 +29,24 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = helpers.MockBot() - self.role_syncer_patcher = mock.patch( + role_syncer_patcher = mock.patch( "bot.exts.backend.sync._syncers.RoleSyncer", autospec=Syncer, spec_set=True ) - self.user_syncer_patcher = mock.patch( + user_syncer_patcher = mock.patch( "bot.exts.backend.sync._syncers.UserSyncer", autospec=Syncer, spec_set=True ) - self.RoleSyncer = self.role_syncer_patcher.start() - self.UserSyncer = self.user_syncer_patcher.start() - self.cog = Sync(self.bot) + self.RoleSyncer = role_syncer_patcher.start() + self.UserSyncer = user_syncer_patcher.start() - def tearDown(self): - self.role_syncer_patcher.stop() - self.user_syncer_patcher.stop() + self.addCleanup(role_syncer_patcher.stop) + self.addCleanup(user_syncer_patcher.stop) + + self.cog = Sync(self.bot) @staticmethod def response_error(status: int) -> ResponseCodeError: @@ -73,8 +73,6 @@ class SyncCogTests(SyncCogTestCase): Sync(self.bot) - self.RoleSyncer.assert_called_once_with(self.bot) - self.UserSyncer.assert_called_once_with(self.bot) sync_guild.assert_called_once_with() self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) @@ -83,8 +81,8 @@ class SyncCogTests(SyncCogTestCase): for guild in (helpers.MockGuild(), None): with self.subTest(guild=guild): self.bot.reset_mock() - self.cog.role_syncer.reset_mock() - self.cog.user_syncer.reset_mock() + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() self.bot.get_guild = mock.MagicMock(return_value=guild) @@ -94,11 +92,11 @@ class SyncCogTests(SyncCogTestCase): self.bot.get_guild.assert_called_once_with(constants.Guild.id) if guild is None: - self.cog.role_syncer.sync.assert_not_called() - self.cog.user_syncer.sync.assert_not_called() + self.RoleSyncer.sync.assert_not_called() + self.UserSyncer.sync.assert_not_called() else: - self.cog.role_syncer.sync.assert_called_once_with(guild) - self.cog.user_syncer.sync.assert_called_once_with(guild) + self.RoleSyncer.sync.assert_called_once_with(guild) + self.UserSyncer.sync.assert_called_once_with(guild) async def patch_user_helper(self, side_effect: BaseException) -> None: """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" @@ -394,14 +392,14 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): ctx = helpers.MockContext() await self.cog.sync_roles_command(self.cog, ctx) - self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + self.RoleSyncer.sync.assert_called_once_with(ctx.guild, ctx) async def test_sync_users_command(self): """sync() should be called on the UserSyncer.""" ctx = helpers.MockContext() await self.cog.sync_users_command(self.cog, ctx) - self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + self.UserSyncer.sync.assert_called_once_with(ctx.guild, ctx) async def test_commands_require_admin(self): """The sync commands should only run if the author has the administrator permission.""" diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py index 7b9f40cad..541074336 100644 --- a/tests/bot/exts/backend/sync/test_roles.py +++ b/tests/bot/exts/backend/sync/test_roles.py @@ -22,8 +22,9 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between roles in the DB and roles in the Guild cache.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) @staticmethod def get_guild(*roles): @@ -44,7 +45,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role()] guild = self.get_guild(fake_role()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = (set(), set(), set()) self.assertEqual(actual_diff, expected_diff) @@ -56,7 +57,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] guild = self.get_guild(updated_role, fake_role()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = (set(), {_Role(**updated_role)}, set()) self.assertEqual(actual_diff, expected_diff) @@ -68,7 +69,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role()] guild = self.get_guild(fake_role(), new_role) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = ({_Role(**new_role)}, set(), set()) self.assertEqual(actual_diff, expected_diff) @@ -80,7 +81,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.bot.api_client.get.return_value = [fake_role(), deleted_role] guild = self.get_guild(fake_role()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = (set(), set(), {_Role(**deleted_role)}) self.assertEqual(actual_diff, expected_diff) @@ -98,7 +99,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): ] guild = self.get_guild(fake_role(), new, updated) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await RoleSyncer._get_diff(guild) expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) self.assertEqual(actual_diff, expected_diff) @@ -108,8 +109,9 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync roles.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = RoleSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) async def test_sync_created_roles(self): """Only POST requests should be made with the correct payload.""" @@ -117,7 +119,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): role_tuples = {_Role(**role) for role in roles} diff = _Diff(role_tuples, set(), set()) - await self.syncer._sync(diff) + await RoleSyncer._sync(diff) calls = [mock.call("bot/roles", json=role) for role in roles] self.bot.api_client.post.assert_has_calls(calls, any_order=True) @@ -132,7 +134,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): role_tuples = {_Role(**role) for role in roles} diff = _Diff(set(), role_tuples, set()) - await self.syncer._sync(diff) + await RoleSyncer._sync(diff) calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] self.bot.api_client.put.assert_has_calls(calls, any_order=True) @@ -147,7 +149,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): role_tuples = {_Role(**role) for role in roles} diff = _Diff(set(), set(), role_tuples) - await self.syncer._sync(diff) + await RoleSyncer._sync(diff) calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] self.bot.api_client.delete.assert_has_calls(calls, any_order=True) diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py index 9f380a15d..61673e1bb 100644 --- a/tests/bot/exts/backend/sync/test_users.py +++ b/tests/bot/exts/backend/sync/test_users.py @@ -1,4 +1,5 @@ import unittest +from unittest import mock from bot.exts.backend.sync._syncers import UserSyncer, _Diff from tests import helpers @@ -19,8 +20,9 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): """Tests for determining differences between users in the DB and users in the Guild cache.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) @staticmethod def get_guild(*members): @@ -57,7 +59,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): } guild = self.get_guild() - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([], [], None) self.assertEqual(actual_diff, expected_diff) @@ -73,7 +75,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): guild = self.get_guild(fake_user()) guild.get_member.return_value = self.get_mock_member(fake_user()) - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([], [], None) self.assertEqual(actual_diff, expected_diff) @@ -94,7 +96,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.get_mock_member(fake_user()) ] - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([], [{"id": 99, "name": "new"}], None) self.assertEqual(actual_diff, expected_diff) @@ -114,7 +116,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): self.get_mock_member(fake_user()), self.get_mock_member(new_user) ] - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([new_user], [], None) self.assertEqual(actual_diff, expected_diff) @@ -133,7 +135,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): None ] - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([], [{"id": 63, "in_guild": False}], None) self.assertEqual(actual_diff, expected_diff) @@ -157,7 +159,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): None ] - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None) self.assertEqual(actual_diff, expected_diff) @@ -176,7 +178,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): None ] - actual_diff = await self.syncer._get_diff(guild) + actual_diff = await UserSyncer._get_diff(guild) expected_diff = ([], [], None) self.assertEqual(actual_diff, expected_diff) @@ -186,15 +188,16 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): """Tests for the API requests that sync users.""" def setUp(self): - self.bot = helpers.MockBot() - self.syncer = UserSyncer(self.bot) + patcher = mock.patch("bot.instance", new=helpers.MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) async def test_sync_created_users(self): """Only POST requests should be made with the correct payload.""" users = [fake_user(id=111), fake_user(id=222)] diff = _Diff(users, [], None) - await self.syncer._sync(diff) + await UserSyncer._sync(diff) self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created) @@ -206,7 +209,7 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): users = [fake_user(id=111), fake_user(id=222)] diff = _Diff([], users, None) - await self.syncer._sync(diff) + await UserSyncer._sync(diff) self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated) diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py index 9a42d0610..321a92445 100644 --- a/tests/bot/exts/utils/test_snekbox.py +++ b/tests/bot/exts/utils/test_snekbox.py @@ -42,9 +42,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): async def test_upload_output(self, mock_paste_util): """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" await self.cog.upload_output("Test output.") - mock_paste_util.assert_called_once_with( - self.bot.http_session, "Test output.", extension="txt" - ) + mock_paste_util.assert_called_once_with("Test output.", extension="txt") def test_prepare_input(self): cases = ( diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py index 5e0855704..1b48f6560 100644 --- a/tests/bot/utils/test_services.py +++ b/tests/bot/utils/test_services.py @@ -5,11 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch from aiohttp import ClientConnectorError from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service +from tests.helpers import MockBot class PasteTests(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: - self.http_session = MagicMock() + patcher = patch("bot.instance", new=MockBot()) + self.bot = patcher.start() + self.addCleanup(patcher.stop) @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") async def test_url_and_sent_contents(self): @@ -17,10 +20,10 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): response = MagicMock( json=AsyncMock(return_value={"key": ""}) ) - self.http_session.post().__aenter__.return_value = response - self.http_session.post.reset_mock() - await send_to_paste_service(self.http_session, "Content") - self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") + self.bot.http_session.post.return_value.__aenter__.return_value = response + self.bot.http_session.post.reset_mock() + await send_to_paste_service("Content") + self.bot.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content") @patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}") async def test_paste_returns_correct_url_on_success(self): @@ -34,41 +37,41 @@ class PasteTests(unittest.IsolatedAsyncioTestCase): response = MagicMock( json=AsyncMock(return_value={"key": key}) ) - self.http_session.post().__aenter__.return_value = response + self.bot.http_session.post.return_value.__aenter__.return_value = response for expected_output, extension in test_cases: with self.subTest(msg=f"Send contents with extension {repr(extension)}"): self.assertEqual( - await send_to_paste_service(self.http_session, "", extension=extension), + await send_to_paste_service("", extension=extension), expected_output ) async def test_request_repeated_on_json_errors(self): """Json with error message and invalid json are handled as errors and requests repeated.""" test_cases = ({"message": "error"}, {"unexpected_key": None}, {}) - self.http_session.post().__aenter__.return_value = response = MagicMock() - self.http_session.post.reset_mock() + self.bot.http_session.post.return_value.__aenter__.return_value = response = MagicMock() + self.bot.http_session.post.reset_mock() for error_json in test_cases: with self.subTest(error_json=error_json): response.json = AsyncMock(return_value=error_json) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + result = await send_to_paste_service("") + self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertIsNone(result) - self.http_session.post.reset_mock() + self.bot.http_session.post.reset_mock() async def test_request_repeated_on_connection_errors(self): """Requests are repeated in the case of connection errors.""" - self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock())) + result = await send_to_paste_service("") + self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertIsNone(result) async def test_general_error_handled_and_request_repeated(self): """All `Exception`s are handled, logged and request repeated.""" - self.http_session.post = MagicMock(side_effect=Exception) - result = await send_to_paste_service(self.http_session, "") - self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) + self.bot.http_session.post = MagicMock(side_effect=Exception) + result = await send_to_paste_service("") + self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS) self.assertLogs("bot.utils", logging.ERROR) self.assertIsNone(result) |