aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Leon Sandøy <[email protected]>2020-11-19 02:14:42 +0100
committerGravatar GitHub <[email protected]>2020-11-19 02:14:42 +0100
commit5400642655b4ee84a0e9e1f24e94fcbc9236c105 (patch)
treee96ee7dad3cc7d9b881e52181fc2c62ba364a6d1
parentUpdate snekbox address in config-default.yml (diff)
parentMerge pull request #1290 from python-discord/sebastiaan/backend/improve-actio... (diff)
Merge branch 'master' into kubernetes-deploy
-rw-r--r--.github/workflows/build.yml57
-rw-r--r--.github/workflows/lint-test.yml115
-rw-r--r--.gitignore1
-rw-r--r--Pipfile4
-rw-r--r--README.md8
-rw-r--r--bot/__init__.py67
-rw-r--r--bot/__main__.py78
-rw-r--r--bot/bot.py60
-rw-r--r--bot/constants.py2
-rw-r--r--bot/exts/backend/sync/_cog.py9
-rw-r--r--bot/exts/backend/sync/_syncers.py66
-rw-r--r--bot/exts/help_channels.py11
-rw-r--r--bot/exts/info/codeblock/_cog.py2
-rw-r--r--bot/exts/info/doc.py2
-rw-r--r--bot/exts/info/help.py6
-rw-r--r--bot/exts/info/tags.py2
-rw-r--r--bot/exts/utils/internal.py4
-rw-r--r--bot/exts/utils/snekbox.py4
-rw-r--r--bot/interpreter.py6
-rw-r--r--bot/log.py86
-rw-r--r--bot/utils/channel.py7
-rw-r--r--bot/utils/messages.py4
-rw-r--r--bot/utils/services.py8
-rw-r--r--docker-compose.yml2
-rw-r--r--tests/bot/exts/backend/sync/test_base.py29
-rw-r--r--tests/bot/exts/backend/sync/test_cog.py34
-rw-r--r--tests/bot/exts/backend/sync/test_roles.py26
-rw-r--r--tests/bot/exts/backend/sync/test_users.py29
-rw-r--r--tests/bot/exts/utils/test_snekbox.py4
-rw-r--r--tests/bot/utils/test_services.py39
30 files changed, 486 insertions, 286 deletions
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 000000000..706ab462f
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,57 @@
+name: Build
+
+on:
+ workflow_run:
+ workflows: ["Lint & Test"]
+ branches:
+ - master
+ types:
+ - completed
+
+jobs:
+ build:
+ if: github.event.workflow_run.conclusion == 'success'
+ name: Build & Push
+ runs-on: ubuntu-latest
+
+ steps:
+ # Create a commit SHA-based tag for the container repositories
+ - name: Create SHA Container Tag
+ id: sha_tag
+ run: |
+ tag=$(cut -c 1-7 <<< $GITHUB_SHA)
+ echo "::set-output name=tag::$tag"
+
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ # The current version (v2) of Docker's build-push action uses
+ # buildx, which comes with BuildKit features that help us speed
+ # up our builds using additional cache features. Buildx also
+ # has a lot of other features that are not as relevant to us.
+ #
+ # See https://github.com/docker/build-push-action
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v1
+
+ - name: Login to Github Container Registry
+ uses: docker/login-action@v1
+ with:
+ registry: ghcr.io
+ username: ${{ github.repository_owner }}
+ password: ${{ secrets.GHCR_TOKEN }}
+
+ # Build and push the container to the GitHub Container
+ # Repository. The container will be tagged as "latest"
+ # and with the short SHA of the commit.
+ - name: Build and push
+ uses: docker/build-push-action@v2
+ with:
+ context: .
+ file: ./Dockerfile
+ push: true
+ cache-from: type=registry,ref=ghcr.io/python-discord/bot:latest
+ cache-to: type=inline
+ tags: |
+ ghcr.io/python-discord/bot:latest
+ ghcr.io/python-discord/bot:${{ steps.sha_tag.outputs.tag }}
diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml
new file mode 100644
index 000000000..5444fc3de
--- /dev/null
+++ b/.github/workflows/lint-test.yml
@@ -0,0 +1,115 @@
+name: Lint & Test
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+
+
+jobs:
+ lint-test:
+ runs-on: ubuntu-latest
+ env:
+ # Dummy values for required bot environment variables
+ BOT_API_KEY: foo
+ BOT_SENTRY_DSN: blah
+ BOT_TOKEN: bar
+ REDDIT_CLIENT_ID: spam
+ REDDIT_SECRET: ham
+ REDIS_PASSWORD: ''
+
+ # Configure pip to cache dependencies and do a user install
+ PIP_NO_CACHE_DIR: false
+ PIP_USER: 1
+
+ # Hide the graphical elements from pipenv's output
+ PIPENV_HIDE_EMOJIS: 1
+ PIPENV_NOSPIN: 1
+
+ # Make sure pipenv does not try reuse an environment it's running in
+ PIPENV_IGNORE_VIRTUALENVS: 1
+
+ # Specify explicit paths for python dependencies and the pre-commit
+ # environment so we know which directories to cache
+ PYTHONUSERBASE: ${{ github.workspace }}/.cache/py-user-base
+ PRE_COMMIT_HOME: ${{ github.workspace }}/.cache/pre-commit-cache
+
+ steps:
+ - name: Add custom PYTHONUSERBASE to PATH
+ run: echo '${{ env.PYTHONUSERBASE }}/bin/' >> $GITHUB_PATH
+
+ - name: Checkout repository
+ uses: actions/checkout@v2
+
+ - name: Setup python
+ id: python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.8'
+
+ # This step caches our Python dependencies. To make sure we
+ # only restore a cache when the dependencies, the python version,
+ # the runner operating system, and the dependency location haven't
+ # changed, we create a cache key that is a composite of those states.
+ #
+ # Only when the context is exactly the same, we will restore the cache.
+ - name: Python Dependency Caching
+ uses: actions/cache@v2
+ id: python_cache
+ with:
+ path: ${{ env.PYTHONUSERBASE }}
+ key: "python-0-${{ runner.os }}-${{ env.PYTHONUSERBASE }}-\
+ ${{ steps.python.outputs.python-version }}-\
+ ${{ hashFiles('./Pipfile', './Pipfile.lock') }}"
+
+ # Install our dependencies if we did not restore a dependency cache
+ - name: Install dependencies using pipenv
+ if: steps.python_cache.outputs.cache-hit != 'true'
+ run: |
+ pip install pipenv
+ pipenv install --dev --deploy --system
+
+ # This step caches our pre-commit environment. To make sure we
+ # do create a new environment when our pre-commit setup changes,
+ # we create a cache key based on relevant factors.
+ - name: Pre-commit Environment Caching
+ uses: actions/cache@v2
+ with:
+ path: ${{ env.PRE_COMMIT_HOME }}
+ key: "precommit-0-${{ runner.os }}-${{ env.PRE_COMMIT_HOME }}-\
+ ${{ steps.python.outputs.python-version }}-\
+ ${{ hashFiles('./.pre-commit-config.yaml') }}"
+
+ # We will not run `flake8` here, as we will use a separate flake8
+ # action. As pre-commit does not support user installs, we set
+ # PIP_USER=0 to not do a user install.
+ - name: Run pre-commit hooks
+ run: export PIP_USER=0; SKIP=flake8 pre-commit run --all-files
+
+ # Run flake8 and have it format the linting errors in the format of
+ # the GitHub Workflow command to register error annotations. This
+ # means that our flake8 output is automatically added as an error
+ # annotation to both the run result and in the "Files" tab of a
+ # pull request.
+ #
+ # Format used:
+ # ::error file={filename},line={line},col={col}::{message}
+ - name: Run flake8
+ run: "flake8 \
+ --format='::error file=%(path)s,line=%(row)d,col=%(col)d::\
+ [flake8] %(code)s: %(text)s'"
+
+ # We run `coverage` using the `python` command so we can suppress
+ # irrelevant warnings in our CI output.
+ - name: Run tests and generate coverage report
+ run: |
+ python -Wignore -m coverage run -m unittest
+ coverage report -m
+
+ # This step will publish the coverage reports coveralls.io and
+ # print a "job" link in the output of the GitHub Action
+ - name: Publish coverage report to coveralls.io
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: coveralls
diff --git a/.gitignore b/.gitignore
index 2074887ad..9186dbe06 100644
--- a/.gitignore
+++ b/.gitignore
@@ -111,6 +111,7 @@ ENV/
# Logfiles
log.*
*.log.*
+!log.py
# Custom user configuration
config.yml
diff --git a/Pipfile b/Pipfile
index 0730b9150..103ce84cf 100644
--- a/Pipfile
+++ b/Pipfile
@@ -48,8 +48,8 @@ python_version = "3.8"
start = "python -m bot"
lint = "pre-commit run --all-files"
precommit = "pre-commit install"
-build = "docker build -t pythondiscord/bot:latest -f Dockerfile ."
-push = "docker push pythondiscord/bot:latest"
+build = "docker build -t ghcr.io/python-discord/bot:latest -f Dockerfile ."
+push = "docker push ghcr.io/python-discord/bot:latest"
test = "coverage run -m unittest"
html = "coverage html"
report = "coverage report"
diff --git a/README.md b/README.md
index 482ada08c..210b3e047 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
# Python Utility Bot
[![Discord](https://img.shields.io/static/v1?label=Python%20Discord&logo=discord&message=%3E100k%20members&color=%237289DA&logoColor=white)](https://discord.gg/2B963hn)
-![Lint, Test, Build](https://github.com/python-discord/bot/workflows/Lint,%20Test,%20Build/badge.svg?branch=master)
+[![Lint & Test][1]][2]
+[![Build][3]][4]
[![Coverage Status](https://coveralls.io/repos/github/python-discord/bot/badge.svg)](https://coveralls.io/github/python-discord/bot)
[![License](https://img.shields.io/github/license/python-discord/bot)](LICENSE)
[![Website](https://img.shields.io/badge/website-visit-brightgreen)](https://pythondiscord.com)
@@ -10,3 +11,8 @@ This project is a Discord bot specifically for use with the Python Discord serve
and other tools to help keep the server running like a well-oiled machine.
Read the [Contributing Guide](https://pythondiscord.com/pages/contributing/bot/) on our website if you're interested in helping out.
+
+[1]: https://github.com/python-discord/bot/workflows/Lint%20&%20Test/badge.svg?branch=master
+[2]: https://github.com/python-discord/bot/actions?query=workflow%3A%22Lint+%26+Test%22+branch%3Amaster
+[3]: https://github.com/python-discord/bot/workflows/Build/badge.svg?branch=master
+[4]: https://github.com/python-discord/bot/actions?query=workflow%3ABuild+branch%3Amaster
diff --git a/bot/__init__.py b/bot/__init__.py
index 4fce04532..8f880b8e6 100644
--- a/bot/__init__.py
+++ b/bot/__init__.py
@@ -1,78 +1,25 @@
import asyncio
-import logging
import os
-import sys
from functools import partial, partialmethod
-from logging import Logger, handlers
-from pathlib import Path
+from typing import TYPE_CHECKING
-import coloredlogs
from discord.ext import commands
+from bot import log
from bot.command import Command
-TRACE_LEVEL = logging.TRACE = 5
-logging.addLevelName(TRACE_LEVEL, "TRACE")
-
-
-def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
- """
- Log 'msg % args' with severity 'TRACE'.
-
- To pass exception information, use the keyword argument exc_info with
- a true value, e.g.
-
- logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
- """
- if self.isEnabledFor(TRACE_LEVEL):
- self._log(TRACE_LEVEL, msg, args, **kwargs)
-
-
-Logger.trace = monkeypatch_trace
-
-DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local")
-
-log_level = TRACE_LEVEL if DEBUG_MODE else logging.INFO
-format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
-log_format = logging.Formatter(format_string)
-
-log_file = Path("logs", "bot.log")
-log_file.parent.mkdir(exist_ok=True)
-file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8")
-file_handler.setFormatter(log_format)
-
-root_log = logging.getLogger()
-root_log.setLevel(log_level)
-root_log.addHandler(file_handler)
-
-if "COLOREDLOGS_LEVEL_STYLES" not in os.environ:
- coloredlogs.DEFAULT_LEVEL_STYLES = {
- **coloredlogs.DEFAULT_LEVEL_STYLES,
- "trace": {"color": 246},
- "critical": {"background": "red"},
- "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"]
- }
-
-if "COLOREDLOGS_LOG_FORMAT" not in os.environ:
- coloredlogs.DEFAULT_LOG_FORMAT = format_string
-
-if "COLOREDLOGS_LOG_LEVEL" not in os.environ:
- coloredlogs.DEFAULT_LOG_LEVEL = log_level
-
-coloredlogs.install(logger=root_log, stream=sys.stdout)
-
-logging.getLogger("discord").setLevel(logging.WARNING)
-logging.getLogger("websockets").setLevel(logging.WARNING)
-logging.getLogger("chardet").setLevel(logging.WARNING)
-logging.getLogger("async_rediscache").setLevel(logging.WARNING)
+if TYPE_CHECKING:
+ from bot.bot import Bot
+log.setup()
# On Windows, the selector event loop is required for aiodns.
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
# Monkey-patch discord.py decorators to use the Command subclass which supports root aliases.
# Must be patched before any cogs are added.
commands.command = partial(commands.command, cls=Command)
commands.GroupMixin.command = partialmethod(commands.GroupMixin.command, cls=Command)
+
+instance: "Bot" = None # Global Bot instance.
diff --git a/bot/__main__.py b/bot/__main__.py
index 367be1300..257216fa7 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -1,76 +1,10 @@
-import asyncio
-import logging
-
-import discord
-import sentry_sdk
-from async_rediscache import RedisSession
-from discord.ext.commands import when_mentioned_or
-from sentry_sdk.integrations.aiohttp import AioHttpIntegration
-from sentry_sdk.integrations.logging import LoggingIntegration
-from sentry_sdk.integrations.redis import RedisIntegration
-
+import bot
from bot import constants
from bot.bot import Bot
-from bot.utils.extensions import EXTENSIONS
-
-# Set up Sentry.
-sentry_logging = LoggingIntegration(
- level=logging.DEBUG,
- event_level=logging.WARNING
-)
-
-sentry_sdk.init(
- dsn=constants.Bot.sentry_dsn,
- integrations=[
- sentry_logging,
- AioHttpIntegration(),
- RedisIntegration(),
- ]
-)
-
-# Create the redis session instance.
-redis_session = RedisSession(
- address=(constants.Redis.host, constants.Redis.port),
- password=constants.Redis.password,
- minsize=1,
- maxsize=20,
- use_fakeredis=constants.Redis.use_fakeredis,
- global_namespace="bot",
-)
-
-# Connect redis session to ensure it's connected before we try to access Redis
-# from somewhere within the bot. We create the event loop in the same way
-# discord.py normally does and pass it to the bot's __init__.
-loop = asyncio.get_event_loop()
-loop.run_until_complete(redis_session.connect())
-
-
-# Instantiate the bot.
-allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES]
-intents = discord.Intents().all()
-intents.presences = False
-intents.dm_typing = False
-intents.dm_reactions = False
-intents.invites = False
-intents.webhooks = False
-intents.integrations = False
-bot = Bot(
- redis_session=redis_session,
- loop=loop,
- command_prefix=when_mentioned_or(constants.Bot.prefix),
- activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),
- case_insensitive=True,
- max_messages=10_000,
- allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles),
- intents=intents,
-)
-
-# Load extensions.
-extensions = set(EXTENSIONS) # Create a mutable copy.
-if not constants.HelpChannels.enable:
- extensions.remove("bot.exts.help_channels")
+from bot.log import setup_sentry
-for extension in extensions:
- bot.load_extension(extension)
+setup_sentry()
-bot.run(constants.Bot.token)
+bot.instance = Bot.create()
+bot.instance.load_extensions()
+bot.instance.run(constants.Bot.token)
diff --git a/bot/bot.py b/bot/bot.py
index b2e5237fe..36cf7d30a 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -11,7 +11,7 @@ from async_rediscache import RedisSession
from discord.ext import commands
from sentry_sdk import push_scope
-from bot import DEBUG_MODE, api, constants
+from bot import api, constants
from bot.async_stats import AsyncStatsClient
log = logging.getLogger('bot')
@@ -40,7 +40,7 @@ class Bot(commands.Bot):
statsd_url = constants.Stats.statsd_host
- if DEBUG_MODE:
+ if constants.DEBUG_MODE:
# Since statsd is UDP, there are no errors for sending to a down port.
# For this reason, setting the statsd host to 127.0.0.1 for development
# will effectively disable stats.
@@ -95,6 +95,43 @@ class Bot(commands.Bot):
# Build the FilterList cache
self.loop.create_task(self.cache_filter_list_data())
+ @classmethod
+ def create(cls) -> "Bot":
+ """Create and return an instance of a Bot."""
+ loop = asyncio.get_event_loop()
+ allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES]
+
+ intents = discord.Intents().all()
+ intents.presences = False
+ intents.dm_typing = False
+ intents.dm_reactions = False
+ intents.invites = False
+ intents.webhooks = False
+ intents.integrations = False
+
+ return cls(
+ redis_session=_create_redis_session(loop),
+ loop=loop,
+ command_prefix=commands.when_mentioned_or(constants.Bot.prefix),
+ activity=discord.Game(name=f"Commands: {constants.Bot.prefix}help"),
+ case_insensitive=True,
+ max_messages=10_000,
+ allowed_mentions=discord.AllowedMentions(everyone=False, roles=allowed_roles),
+ intents=intents,
+ )
+
+ def load_extensions(self) -> None:
+ """Load all enabled extensions."""
+ # Must be done here to avoid a circular import.
+ from bot.utils.extensions import EXTENSIONS
+
+ extensions = set(EXTENSIONS) # Create a mutable copy.
+ if not constants.HelpChannels.enable:
+ extensions.remove("bot.exts.help_channels")
+
+ for extension in extensions:
+ self.load_extension(extension)
+
def add_cog(self, cog: commands.Cog) -> None:
"""Adds a "cog" to the bot and logs the operation."""
super().add_cog(cog)
@@ -243,3 +280,22 @@ class Bot(commands.Bot):
for alias in getattr(command, "root_aliases", ()):
self.all_commands.pop(alias, None)
+
+
+def _create_redis_session(loop: asyncio.AbstractEventLoop) -> RedisSession:
+ """
+ Create and connect to a redis session.
+
+ Ensure the connection is established before returning to prevent race conditions.
+ `loop` is the event loop on which to connect. The Bot should use this same event loop.
+ """
+ redis_session = RedisSession(
+ address=(constants.Redis.host, constants.Redis.port),
+ password=constants.Redis.password,
+ minsize=1,
+ maxsize=20,
+ use_fakeredis=constants.Redis.use_fakeredis,
+ global_namespace="bot",
+ )
+ loop.run_until_complete(redis_session.connect())
+ return redis_session
diff --git a/bot/constants.py b/bot/constants.py
index 731f06fed..2126b2b37 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -632,7 +632,7 @@ class Event(Enum):
# Debug mode
-DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False
+DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local")
# Paths
BOT_DIR = os.path.dirname(__file__)
diff --git a/bot/exts/backend/sync/_cog.py b/bot/exts/backend/sync/_cog.py
index 6e85e2b7d..48d2b6f02 100644
--- a/bot/exts/backend/sync/_cog.py
+++ b/bot/exts/backend/sync/_cog.py
@@ -18,9 +18,6 @@ class Sync(Cog):
def __init__(self, bot: Bot) -> None:
self.bot = bot
- self.role_syncer = _syncers.RoleSyncer(self.bot)
- self.user_syncer = _syncers.UserSyncer(self.bot)
-
self.bot.loop.create_task(self.sync_guild())
async def sync_guild(self) -> None:
@@ -31,7 +28,7 @@ class Sync(Cog):
if guild is None:
return
- for syncer in (self.role_syncer, self.user_syncer):
+ for syncer in (_syncers.RoleSyncer, _syncers.UserSyncer):
await syncer.sync(guild)
async def patch_user(self, user_id: int, json: Dict[str, Any], ignore_404: bool = False) -> None:
@@ -171,10 +168,10 @@ class Sync(Cog):
@commands.has_permissions(administrator=True)
async def sync_roles_command(self, ctx: Context) -> None:
"""Manually synchronise the guild's roles with the roles on the site."""
- await self.role_syncer.sync(ctx.guild, ctx)
+ await _syncers.RoleSyncer.sync(ctx.guild, ctx)
@sync_group.command(name='users')
@commands.has_permissions(administrator=True)
async def sync_users_command(self, ctx: Context) -> None:
"""Manually synchronise the guild's users with the users on the site."""
- await self.user_syncer.sync(ctx.guild, ctx)
+ await _syncers.UserSyncer.sync(ctx.guild, ctx)
diff --git a/bot/exts/backend/sync/_syncers.py b/bot/exts/backend/sync/_syncers.py
index 38468c2b1..2eb9f9971 100644
--- a/bot/exts/backend/sync/_syncers.py
+++ b/bot/exts/backend/sync/_syncers.py
@@ -6,8 +6,8 @@ from collections import namedtuple
from discord import Guild
from discord.ext.commands import Context
+import bot
from bot.api import ResponseCodeError
-from bot.bot import Bot
log = logging.getLogger(__name__)
@@ -17,57 +17,60 @@ _Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
_Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
+# Implementation of static abstract methods are not enforced if the subclass is never instantiated.
+# However, methods are kept abstract to at least symbolise that they should be abstract.
class Syncer(abc.ABC):
"""Base class for synchronising the database with objects in the Discord cache."""
- def __init__(self, bot: Bot) -> None:
- self.bot = bot
-
+ @staticmethod
@property
@abc.abstractmethod
- def name(self) -> str:
+ def name() -> str:
"""The name of the syncer; used in output messages and logging."""
raise NotImplementedError # pragma: no cover
+ @staticmethod
@abc.abstractmethod
- async def _get_diff(self, guild: Guild) -> _Diff:
+ async def _get_diff(guild: Guild) -> _Diff:
"""Return the difference between the cache of `guild` and the database."""
raise NotImplementedError # pragma: no cover
+ @staticmethod
@abc.abstractmethod
- async def _sync(self, diff: _Diff) -> None:
+ async def _sync(diff: _Diff) -> None:
"""Perform the API calls for synchronisation."""
raise NotImplementedError # pragma: no cover
- async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None:
+ @classmethod
+ async def sync(cls, guild: Guild, ctx: t.Optional[Context] = None) -> None:
"""
Synchronise the database with the cache of `guild`.
If `ctx` is given, send a message with the results.
"""
- log.info(f"Starting {self.name} syncer.")
+ log.info(f"Starting {cls.name} syncer.")
if ctx:
- message = await ctx.send(f"📊 Synchronising {self.name}s.")
+ message = await ctx.send(f"📊 Synchronising {cls.name}s.")
else:
message = None
- diff = await self._get_diff(guild)
+ diff = await cls._get_diff(guild)
try:
- await self._sync(diff)
+ await cls._sync(diff)
except ResponseCodeError as e:
- log.exception(f"{self.name} syncer failed!")
+ log.exception(f"{cls.name} syncer failed!")
# Don't show response text because it's probably some really long HTML.
results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```"
- content = f":x: Synchronisation of {self.name}s failed: {results}"
+ content = f":x: Synchronisation of {cls.name}s failed: {results}"
else:
diff_dict = diff._asdict()
results = (f"{name} `{len(val)}`" for name, val in diff_dict.items() if val is not None)
results = ", ".join(results)
- log.info(f"{self.name} syncer finished: {results}.")
- content = f":ok_hand: Synchronisation of {self.name}s complete: {results}"
+ log.info(f"{cls.name} syncer finished: {results}.")
+ content = f":ok_hand: Synchronisation of {cls.name}s complete: {results}"
if message:
await message.edit(content=content)
@@ -78,10 +81,11 @@ class RoleSyncer(Syncer):
name = "role"
- async def _get_diff(self, guild: Guild) -> _Diff:
+ @staticmethod
+ async def _get_diff(guild: Guild) -> _Diff:
"""Return the difference of roles between the cache of `guild` and the database."""
log.trace("Getting the diff for roles.")
- roles = await self.bot.api_client.get('bot/roles')
+ roles = await bot.instance.api_client.get('bot/roles')
# Pack DB roles and guild roles into one common, hashable format.
# They're hashable so that they're easily comparable with sets later.
@@ -110,19 +114,20 @@ class RoleSyncer(Syncer):
return _Diff(roles_to_create, roles_to_update, roles_to_delete)
- async def _sync(self, diff: _Diff) -> None:
+ @staticmethod
+ async def _sync(diff: _Diff) -> None:
"""Synchronise the database with the role cache of `guild`."""
log.trace("Syncing created roles...")
for role in diff.created:
- await self.bot.api_client.post('bot/roles', json=role._asdict())
+ await bot.instance.api_client.post('bot/roles', json=role._asdict())
log.trace("Syncing updated roles...")
for role in diff.updated:
- await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict())
+ await bot.instance.api_client.put(f'bot/roles/{role.id}', json=role._asdict())
log.trace("Syncing deleted roles...")
for role in diff.deleted:
- await self.bot.api_client.delete(f'bot/roles/{role.id}')
+ await bot.instance.api_client.delete(f'bot/roles/{role.id}')
class UserSyncer(Syncer):
@@ -130,7 +135,8 @@ class UserSyncer(Syncer):
name = "user"
- async def _get_diff(self, guild: Guild) -> _Diff:
+ @staticmethod
+ async def _get_diff(guild: Guild) -> _Diff:
"""Return the difference of users between the cache of `guild` and the database."""
log.trace("Getting the diff for users.")
@@ -138,7 +144,7 @@ class UserSyncer(Syncer):
users_to_update = []
seen_guild_users = set()
- async for db_user in self._get_users():
+ async for db_user in UserSyncer._get_users():
# Store user fields which are to be updated.
updated_fields = {}
@@ -185,24 +191,26 @@ class UserSyncer(Syncer):
return _Diff(users_to_create, users_to_update, None)
- async def _get_users(self) -> t.AsyncIterable:
+ @staticmethod
+ async def _get_users() -> t.AsyncIterable:
"""GET users from database."""
query_params = {
"page": 1
}
while query_params["page"]:
- res = await self.bot.api_client.get("bot/users", params=query_params)
+ res = await bot.instance.api_client.get("bot/users", params=query_params)
for user in res["results"]:
yield user
query_params["page"] = res["next_page_no"]
- async def _sync(self, diff: _Diff) -> None:
+ @staticmethod
+ async def _sync(diff: _Diff) -> None:
"""Synchronise the database with the user cache of `guild`."""
log.trace("Syncing created users...")
if diff.created:
- await self.bot.api_client.post("bot/users", json=diff.created)
+ await bot.instance.api_client.post("bot/users", json=diff.created)
log.trace("Syncing updated users...")
if diff.updated:
- await self.bot.api_client.patch("bot/users/bulk_patch", json=diff.updated)
+ await bot.instance.api_client.patch("bot/users/bulk_patch", json=diff.updated)
diff --git a/bot/exts/help_channels.py b/bot/exts/help_channels.py
index 062d4fcfe..f5a8b251b 100644
--- a/bot/exts/help_channels.py
+++ b/bot/exts/help_channels.py
@@ -380,16 +380,13 @@ class HelpChannels(commands.Cog):
try:
self.available_category = await channel_utils.try_get_channel(
- constants.Categories.help_available,
- self.bot
+ constants.Categories.help_available
)
self.in_use_category = await channel_utils.try_get_channel(
- constants.Categories.help_in_use,
- self.bot
+ constants.Categories.help_in_use
)
self.dormant_category = await channel_utils.try_get_channel(
- constants.Categories.help_dormant,
- self.bot
+ constants.Categories.help_dormant
)
except discord.HTTPException:
log.exception("Failed to get a category; cog will be removed")
@@ -500,7 +497,7 @@ class HelpChannels(commands.Cog):
options should be avoided, as it may interfere with the category move we perform.
"""
# Get a fresh copy of the category from the bot to avoid the cache mismatch issue we had.
- category = await channel_utils.try_get_channel(category_id, self.bot)
+ category = await channel_utils.try_get_channel(category_id)
payload = [{"id": c.id, "position": c.position} for c in category.channels]
diff --git a/bot/exts/info/codeblock/_cog.py b/bot/exts/info/codeblock/_cog.py
index 1e0feab0d..9094d9d15 100644
--- a/bot/exts/info/codeblock/_cog.py
+++ b/bot/exts/info/codeblock/_cog.py
@@ -114,7 +114,7 @@ class CodeBlockCog(Cog, name="Code Block"):
bot_message = await message.channel.send(f"Hey {message.author.mention}!", embed=embed)
self.codeblock_message_ids[message.id] = bot_message.id
- self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,), self.bot))
+ self.bot.loop.create_task(wait_for_deletion(bot_message, (message.author.id,)))
# Increase amount of codeblock correction in stats
self.bot.stats.incr("codeblock_corrections")
diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py
index 7ec8caa4b..9b5bd6504 100644
--- a/bot/exts/info/doc.py
+++ b/bot/exts/info/doc.py
@@ -365,7 +365,7 @@ class Doc(commands.Cog):
await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY)
else:
msg = await ctx.send(embed=doc_embed)
- await wait_for_deletion(msg, (ctx.author.id,), client=self.bot)
+ await wait_for_deletion(msg, (ctx.author.id,))
@docs_group.command(name='set', aliases=('s',))
@commands.has_any_role(*MODERATION_ROLES)
diff --git a/bot/exts/info/help.py b/bot/exts/info/help.py
index 599c5d5c0..461ff82fd 100644
--- a/bot/exts/info/help.py
+++ b/bot/exts/info/help.py
@@ -186,7 +186,7 @@ class CustomHelpCommand(HelpCommand):
"""Send help for a single command."""
embed = await self.command_formatting(command)
message = await self.context.send(embed=embed)
- await wait_for_deletion(message, (self.context.author.id,), self.context.bot)
+ await wait_for_deletion(message, (self.context.author.id,))
@staticmethod
def get_commands_brief_details(commands_: List[Command], return_as_list: bool = False) -> Union[List[str], str]:
@@ -225,7 +225,7 @@ class CustomHelpCommand(HelpCommand):
embed.description += f"\n**Subcommands:**\n{command_details}"
message = await self.context.send(embed=embed)
- await wait_for_deletion(message, (self.context.author.id,), self.context.bot)
+ await wait_for_deletion(message, (self.context.author.id,))
async def send_cog_help(self, cog: Cog) -> None:
"""Send help for a cog."""
@@ -241,7 +241,7 @@ class CustomHelpCommand(HelpCommand):
embed.description += f"\n\n**Commands:**\n{command_details}"
message = await self.context.send(embed=embed)
- await wait_for_deletion(message, (self.context.author.id,), self.context.bot)
+ await wait_for_deletion(message, (self.context.author.id,))
@staticmethod
def _category_key(command: Command) -> str:
diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py
index ae95ac1ef..8f15f932b 100644
--- a/bot/exts/info/tags.py
+++ b/bot/exts/info/tags.py
@@ -236,7 +236,6 @@ class Tags(Cog):
await wait_for_deletion(
await ctx.send(embed=Embed.from_dict(tag['embed'])),
[ctx.author.id],
- self.bot
)
elif founds and len(tag_name) >= 3:
await wait_for_deletion(
@@ -247,7 +246,6 @@ class Tags(Cog):
)
),
[ctx.author.id],
- self.bot
)
else:
diff --git a/bot/exts/utils/internal.py b/bot/exts/utils/internal.py
index 1b4900f42..3521c8fd4 100644
--- a/bot/exts/utils/internal.py
+++ b/bot/exts/utils/internal.py
@@ -30,7 +30,7 @@ class Internal(Cog):
self.ln = 0
self.stdout = StringIO()
- self.interpreter = Interpreter(bot)
+ self.interpreter = Interpreter()
self.socket_since = datetime.utcnow()
self.socket_event_total = 0
@@ -195,7 +195,7 @@ async def func(): # (None,) -> Any
truncate_index = newline_truncate_index
if len(out) > truncate_index:
- paste_link = await send_to_paste_service(self.bot.http_session, out, extension="py")
+ paste_link = await send_to_paste_service(out, extension="py")
if paste_link is not None:
paste_text = f"full contents at {paste_link}"
else:
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 41cb00541..9f480c067 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -70,7 +70,7 @@ class Snekbox(Cog):
if len(output) > MAX_PASTE_LEN:
log.info("Full output is too long to upload")
return "too long to upload"
- return await send_to_paste_service(self.bot.http_session, output, extension="txt")
+ return await send_to_paste_service(output, extension="txt")
@staticmethod
def prepare_input(code: str) -> str:
@@ -219,7 +219,7 @@ class Snekbox(Cog):
response = await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.")
else:
response = await ctx.send(msg)
- self.bot.loop.create_task(wait_for_deletion(response, (ctx.author.id,), ctx.bot))
+ self.bot.loop.create_task(wait_for_deletion(response, (ctx.author.id,)))
log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
return response
diff --git a/bot/interpreter.py b/bot/interpreter.py
index 8b7268746..b58f7a6b0 100644
--- a/bot/interpreter.py
+++ b/bot/interpreter.py
@@ -4,7 +4,7 @@ from typing import Any
from discord.ext.commands import Context
-from bot.bot import Bot
+import bot
CODE_TEMPLATE = """
async def _func():
@@ -21,8 +21,8 @@ class Interpreter(InteractiveInterpreter):
write_callable = None
- def __init__(self, bot: Bot):
- locals_ = {"bot": bot}
+ def __init__(self):
+ locals_ = {"bot": bot.instance}
super().__init__(locals_)
async def run(self, code: str, ctx: Context, io: StringIO, *args, **kwargs) -> Any:
diff --git a/bot/log.py b/bot/log.py
new file mode 100644
index 000000000..13141de40
--- /dev/null
+++ b/bot/log.py
@@ -0,0 +1,86 @@
+import logging
+import os
+import sys
+from logging import Logger, handlers
+from pathlib import Path
+
+import coloredlogs
+import sentry_sdk
+from sentry_sdk.integrations.aiohttp import AioHttpIntegration
+from sentry_sdk.integrations.logging import LoggingIntegration
+from sentry_sdk.integrations.redis import RedisIntegration
+
+from bot import constants
+
+TRACE_LEVEL = 5
+
+
+def setup() -> None:
+ """Set up loggers."""
+ logging.TRACE = TRACE_LEVEL
+ logging.addLevelName(TRACE_LEVEL, "TRACE")
+ Logger.trace = _monkeypatch_trace
+
+ log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO
+ format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
+ log_format = logging.Formatter(format_string)
+
+ log_file = Path("logs", "bot.log")
+ log_file.parent.mkdir(exist_ok=True)
+ file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8")
+ file_handler.setFormatter(log_format)
+
+ root_log = logging.getLogger()
+ root_log.setLevel(log_level)
+ root_log.addHandler(file_handler)
+
+ if "COLOREDLOGS_LEVEL_STYLES" not in os.environ:
+ coloredlogs.DEFAULT_LEVEL_STYLES = {
+ **coloredlogs.DEFAULT_LEVEL_STYLES,
+ "trace": {"color": 246},
+ "critical": {"background": "red"},
+ "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"]
+ }
+
+ if "COLOREDLOGS_LOG_FORMAT" not in os.environ:
+ coloredlogs.DEFAULT_LOG_FORMAT = format_string
+
+ if "COLOREDLOGS_LOG_LEVEL" not in os.environ:
+ coloredlogs.DEFAULT_LOG_LEVEL = log_level
+
+ coloredlogs.install(logger=root_log, stream=sys.stdout)
+
+ logging.getLogger("discord").setLevel(logging.WARNING)
+ logging.getLogger("websockets").setLevel(logging.WARNING)
+ logging.getLogger("chardet").setLevel(logging.WARNING)
+ logging.getLogger("async_rediscache").setLevel(logging.WARNING)
+
+
+def setup_sentry() -> None:
+ """Set up the Sentry logging integrations."""
+ sentry_logging = LoggingIntegration(
+ level=logging.DEBUG,
+ event_level=logging.WARNING
+ )
+
+ sentry_sdk.init(
+ dsn=constants.Bot.sentry_dsn,
+ integrations=[
+ sentry_logging,
+ AioHttpIntegration(),
+ RedisIntegration(),
+ ]
+ )
+
+
+def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
+ """
+ Log 'msg % args' with severity 'TRACE'.
+
+ To pass exception information, use the keyword argument exc_info with
+ a true value, e.g.
+
+ logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
+ """
+ if self.isEnabledFor(TRACE_LEVEL):
+ self._log(TRACE_LEVEL, msg, args, **kwargs)
diff --git a/bot/utils/channel.py b/bot/utils/channel.py
index 6bf70bfde..0c072184c 100644
--- a/bot/utils/channel.py
+++ b/bot/utils/channel.py
@@ -2,6 +2,7 @@ import logging
import discord
+import bot
from bot import constants
from bot.constants import Categories
@@ -36,14 +37,14 @@ def is_in_category(channel: discord.TextChannel, category_id: int) -> bool:
return getattr(channel, "category_id", None) == category_id
-async def try_get_channel(channel_id: int, client: discord.Client) -> discord.abc.GuildChannel:
+async def try_get_channel(channel_id: int) -> discord.abc.GuildChannel:
"""Attempt to get or fetch a channel and return it."""
log.trace(f"Getting the channel {channel_id}.")
- channel = client.get_channel(channel_id)
+ channel = bot.instance.get_channel(channel_id)
if not channel:
log.debug(f"Channel {channel_id} is not in cache; fetching from API.")
- channel = await client.fetch_channel(channel_id)
+ channel = await bot.instance.fetch_channel(channel_id)
log.trace(f"Channel #{channel} ({channel_id}) retrieved.")
return channel
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index b6c7cab50..42bde358d 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -10,6 +10,7 @@ import discord
from discord.errors import HTTPException
from discord.ext.commands import Context
+import bot
from bot.constants import Emojis, NEGATIVE_REPLIES
log = logging.getLogger(__name__)
@@ -18,7 +19,6 @@ log = logging.getLogger(__name__)
async def wait_for_deletion(
message: discord.Message,
user_ids: Sequence[discord.abc.Snowflake],
- client: discord.Client,
deletion_emojis: Sequence[str] = (Emojis.trashcan,),
timeout: float = 60 * 5,
attach_emojis: bool = True,
@@ -49,7 +49,7 @@ async def wait_for_deletion(
)
with contextlib.suppress(asyncio.TimeoutError):
- await client.wait_for('reaction_add', check=check, timeout=timeout)
+ await bot.instance.wait_for('reaction_add', check=check, timeout=timeout)
await message.delete()
diff --git a/bot/utils/services.py b/bot/utils/services.py
index 087b9f969..5949c9e48 100644
--- a/bot/utils/services.py
+++ b/bot/utils/services.py
@@ -1,8 +1,9 @@
import logging
from typing import Optional
-from aiohttp import ClientConnectorError, ClientSession
+from aiohttp import ClientConnectorError
+import bot
from bot.constants import URLs
log = logging.getLogger(__name__)
@@ -10,11 +11,10 @@ log = logging.getLogger(__name__)
FAILED_REQUEST_ATTEMPTS = 3
-async def send_to_paste_service(http_session: ClientSession, contents: str, *, extension: str = "") -> Optional[str]:
+async def send_to_paste_service(contents: str, *, extension: str = "") -> Optional[str]:
"""
Upload `contents` to the paste service.
- `http_session` should be the current running ClientSession from aiohttp
`extension` is added to the output URL
When an error occurs, `None` is returned, otherwise the generated URL with the suffix.
@@ -24,7 +24,7 @@ async def send_to_paste_service(http_session: ClientSession, contents: str, *, e
paste_url = URLs.paste_service.format(key="documents")
for attempt in range(1, FAILED_REQUEST_ATTEMPTS + 1):
try:
- async with http_session.post(paste_url, data=contents) as response:
+ async with bot.instance.http_session.post(paste_url, data=contents) as response:
response_json = await response.json()
except ClientConnectorError:
log.warning(
diff --git a/docker-compose.yml b/docker-compose.yml
index dc89e8885..0002d1d56 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -18,7 +18,7 @@ services:
- "127.0.0.1:6379:6379"
snekbox:
- image: pythondiscord/snekbox:latest
+ image: ghcr.io/python-discord/snekbox:latest
init: true
ipc: none
ports:
diff --git a/tests/bot/exts/backend/sync/test_base.py b/tests/bot/exts/backend/sync/test_base.py
index 4953550f9..3ad9db9c3 100644
--- a/tests/bot/exts/backend/sync/test_base.py
+++ b/tests/bot/exts/backend/sync/test_base.py
@@ -15,28 +15,21 @@ class TestSyncer(Syncer):
_sync = mock.AsyncMock()
-class SyncerBaseTests(unittest.TestCase):
- """Tests for the syncer base class."""
-
- def setUp(self):
- self.bot = helpers.MockBot()
-
- def test_instantiation_fails_without_abstract_methods(self):
- """The class must have abstract methods implemented."""
- with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"):
- Syncer(self.bot)
-
-
class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for main function orchestrating the sync."""
def setUp(self):
- self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
- self.syncer = TestSyncer(self.bot)
+ patcher = mock.patch("bot.instance", new=helpers.MockBot(user=helpers.MockMember(bot=True)))
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
+
self.guild = helpers.MockGuild()
+ TestSyncer._get_diff.reset_mock(return_value=True, side_effect=True)
+ TestSyncer._sync.reset_mock(return_value=True, side_effect=True)
+
# Make sure `_get_diff` returns a MagicMock, not an AsyncMock
- self.syncer._get_diff.return_value = mock.MagicMock()
+ TestSyncer._get_diff.return_value = mock.MagicMock()
async def test_sync_message_edited(self):
"""The message should be edited if one was sent, even if the sync has an API error."""
@@ -48,11 +41,11 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
for message, side_effect, should_edit in subtests:
with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
- self.syncer._sync.side_effect = side_effect
+ TestSyncer._sync.side_effect = side_effect
ctx = helpers.MockContext()
ctx.send.return_value = message
- await self.syncer.sync(self.guild, ctx)
+ await TestSyncer.sync(self.guild, ctx)
if should_edit:
message.edit.assert_called_once()
@@ -67,7 +60,7 @@ class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
for ctx, message in subtests:
with self.subTest(ctx=ctx, message=message):
- await self.syncer.sync(self.guild, ctx)
+ await TestSyncer.sync(self.guild, ctx)
if ctx is not None:
ctx.send.assert_called_once()
diff --git a/tests/bot/exts/backend/sync/test_cog.py b/tests/bot/exts/backend/sync/test_cog.py
index 063a82754..22a07313e 100644
--- a/tests/bot/exts/backend/sync/test_cog.py
+++ b/tests/bot/exts/backend/sync/test_cog.py
@@ -29,24 +29,24 @@ class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.bot = helpers.MockBot()
- self.role_syncer_patcher = mock.patch(
+ role_syncer_patcher = mock.patch(
"bot.exts.backend.sync._syncers.RoleSyncer",
autospec=Syncer,
spec_set=True
)
- self.user_syncer_patcher = mock.patch(
+ user_syncer_patcher = mock.patch(
"bot.exts.backend.sync._syncers.UserSyncer",
autospec=Syncer,
spec_set=True
)
- self.RoleSyncer = self.role_syncer_patcher.start()
- self.UserSyncer = self.user_syncer_patcher.start()
- self.cog = Sync(self.bot)
+ self.RoleSyncer = role_syncer_patcher.start()
+ self.UserSyncer = user_syncer_patcher.start()
- def tearDown(self):
- self.role_syncer_patcher.stop()
- self.user_syncer_patcher.stop()
+ self.addCleanup(role_syncer_patcher.stop)
+ self.addCleanup(user_syncer_patcher.stop)
+
+ self.cog = Sync(self.bot)
@staticmethod
def response_error(status: int) -> ResponseCodeError:
@@ -73,8 +73,6 @@ class SyncCogTests(SyncCogTestCase):
Sync(self.bot)
- self.RoleSyncer.assert_called_once_with(self.bot)
- self.UserSyncer.assert_called_once_with(self.bot)
sync_guild.assert_called_once_with()
self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
@@ -83,8 +81,8 @@ class SyncCogTests(SyncCogTestCase):
for guild in (helpers.MockGuild(), None):
with self.subTest(guild=guild):
self.bot.reset_mock()
- self.cog.role_syncer.reset_mock()
- self.cog.user_syncer.reset_mock()
+ self.RoleSyncer.reset_mock()
+ self.UserSyncer.reset_mock()
self.bot.get_guild = mock.MagicMock(return_value=guild)
@@ -94,11 +92,11 @@ class SyncCogTests(SyncCogTestCase):
self.bot.get_guild.assert_called_once_with(constants.Guild.id)
if guild is None:
- self.cog.role_syncer.sync.assert_not_called()
- self.cog.user_syncer.sync.assert_not_called()
+ self.RoleSyncer.sync.assert_not_called()
+ self.UserSyncer.sync.assert_not_called()
else:
- self.cog.role_syncer.sync.assert_called_once_with(guild)
- self.cog.user_syncer.sync.assert_called_once_with(guild)
+ self.RoleSyncer.sync.assert_called_once_with(guild)
+ self.UserSyncer.sync.assert_called_once_with(guild)
async def patch_user_helper(self, side_effect: BaseException) -> None:
"""Helper to set a side effect for bot.api_client.patch and then assert it is called."""
@@ -394,14 +392,14 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
ctx = helpers.MockContext()
await self.cog.sync_roles_command(self.cog, ctx)
- self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+ self.RoleSyncer.sync.assert_called_once_with(ctx.guild, ctx)
async def test_sync_users_command(self):
"""sync() should be called on the UserSyncer."""
ctx = helpers.MockContext()
await self.cog.sync_users_command(self.cog, ctx)
- self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+ self.UserSyncer.sync.assert_called_once_with(ctx.guild, ctx)
async def test_commands_require_admin(self):
"""The sync commands should only run if the author has the administrator permission."""
diff --git a/tests/bot/exts/backend/sync/test_roles.py b/tests/bot/exts/backend/sync/test_roles.py
index 7b9f40cad..541074336 100644
--- a/tests/bot/exts/backend/sync/test_roles.py
+++ b/tests/bot/exts/backend/sync/test_roles.py
@@ -22,8 +22,9 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between roles in the DB and roles in the Guild cache."""
def setUp(self):
- self.bot = helpers.MockBot()
- self.syncer = RoleSyncer(self.bot)
+ patcher = mock.patch("bot.instance", new=helpers.MockBot())
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
@staticmethod
def get_guild(*roles):
@@ -44,7 +45,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.bot.api_client.get.return_value = [fake_role()]
guild = self.get_guild(fake_role())
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await RoleSyncer._get_diff(guild)
expected_diff = (set(), set(), set())
self.assertEqual(actual_diff, expected_diff)
@@ -56,7 +57,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]
guild = self.get_guild(updated_role, fake_role())
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await RoleSyncer._get_diff(guild)
expected_diff = (set(), {_Role(**updated_role)}, set())
self.assertEqual(actual_diff, expected_diff)
@@ -68,7 +69,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.bot.api_client.get.return_value = [fake_role()]
guild = self.get_guild(fake_role(), new_role)
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await RoleSyncer._get_diff(guild)
expected_diff = ({_Role(**new_role)}, set(), set())
self.assertEqual(actual_diff, expected_diff)
@@ -80,7 +81,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.bot.api_client.get.return_value = [fake_role(), deleted_role]
guild = self.get_guild(fake_role())
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await RoleSyncer._get_diff(guild)
expected_diff = (set(), set(), {_Role(**deleted_role)})
self.assertEqual(actual_diff, expected_diff)
@@ -98,7 +99,7 @@ class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
]
guild = self.get_guild(fake_role(), new, updated)
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await RoleSyncer._get_diff(guild)
expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})
self.assertEqual(actual_diff, expected_diff)
@@ -108,8 +109,9 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the API requests that sync roles."""
def setUp(self):
- self.bot = helpers.MockBot()
- self.syncer = RoleSyncer(self.bot)
+ patcher = mock.patch("bot.instance", new=helpers.MockBot())
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
async def test_sync_created_roles(self):
"""Only POST requests should be made with the correct payload."""
@@ -117,7 +119,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
role_tuples = {_Role(**role) for role in roles}
diff = _Diff(role_tuples, set(), set())
- await self.syncer._sync(diff)
+ await RoleSyncer._sync(diff)
calls = [mock.call("bot/roles", json=role) for role in roles]
self.bot.api_client.post.assert_has_calls(calls, any_order=True)
@@ -132,7 +134,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
role_tuples = {_Role(**role) for role in roles}
diff = _Diff(set(), role_tuples, set())
- await self.syncer._sync(diff)
+ await RoleSyncer._sync(diff)
calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]
self.bot.api_client.put.assert_has_calls(calls, any_order=True)
@@ -147,7 +149,7 @@ class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
role_tuples = {_Role(**role) for role in roles}
diff = _Diff(set(), set(), role_tuples)
- await self.syncer._sync(diff)
+ await RoleSyncer._sync(diff)
calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]
self.bot.api_client.delete.assert_has_calls(calls, any_order=True)
diff --git a/tests/bot/exts/backend/sync/test_users.py b/tests/bot/exts/backend/sync/test_users.py
index 9f380a15d..61673e1bb 100644
--- a/tests/bot/exts/backend/sync/test_users.py
+++ b/tests/bot/exts/backend/sync/test_users.py
@@ -1,4 +1,5 @@
import unittest
+from unittest import mock
from bot.exts.backend.sync._syncers import UserSyncer, _Diff
from tests import helpers
@@ -19,8 +20,9 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between users in the DB and users in the Guild cache."""
def setUp(self):
- self.bot = helpers.MockBot()
- self.syncer = UserSyncer(self.bot)
+ patcher = mock.patch("bot.instance", new=helpers.MockBot())
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
@staticmethod
def get_guild(*members):
@@ -57,7 +59,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
}
guild = self.get_guild()
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -73,7 +75,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
guild = self.get_guild(fake_user())
guild.get_member.return_value = self.get_mock_member(fake_user())
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -94,7 +96,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.get_mock_member(fake_user())
]
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [{"id": 99, "name": "new"}], None)
self.assertEqual(actual_diff, expected_diff)
@@ -114,7 +116,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
self.get_mock_member(fake_user()),
self.get_mock_member(new_user)
]
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([new_user], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -133,7 +135,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
None
]
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [{"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
@@ -157,7 +159,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
None
]
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([new_user], [{"id": 55, "name": "updated"}, {"id": 63, "in_guild": False}], None)
self.assertEqual(actual_diff, expected_diff)
@@ -176,7 +178,7 @@ class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
None
]
- actual_diff = await self.syncer._get_diff(guild)
+ actual_diff = await UserSyncer._get_diff(guild)
expected_diff = ([], [], None)
self.assertEqual(actual_diff, expected_diff)
@@ -186,15 +188,16 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the API requests that sync users."""
def setUp(self):
- self.bot = helpers.MockBot()
- self.syncer = UserSyncer(self.bot)
+ patcher = mock.patch("bot.instance", new=helpers.MockBot())
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
async def test_sync_created_users(self):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
diff = _Diff(users, [], None)
- await self.syncer._sync(diff)
+ await UserSyncer._sync(diff)
self.bot.api_client.post.assert_called_once_with("bot/users", json=diff.created)
@@ -206,7 +209,7 @@ class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
users = [fake_user(id=111), fake_user(id=222)]
diff = _Diff([], users, None)
- await self.syncer._sync(diff)
+ await UserSyncer._sync(diff)
self.bot.api_client.patch.assert_called_once_with("bot/users/bulk_patch", json=diff.updated)
diff --git a/tests/bot/exts/utils/test_snekbox.py b/tests/bot/exts/utils/test_snekbox.py
index 9a42d0610..321a92445 100644
--- a/tests/bot/exts/utils/test_snekbox.py
+++ b/tests/bot/exts/utils/test_snekbox.py
@@ -42,9 +42,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
async def test_upload_output(self, mock_paste_util):
"""Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""
await self.cog.upload_output("Test output.")
- mock_paste_util.assert_called_once_with(
- self.bot.http_session, "Test output.", extension="txt"
- )
+ mock_paste_util.assert_called_once_with("Test output.", extension="txt")
def test_prepare_input(self):
cases = (
diff --git a/tests/bot/utils/test_services.py b/tests/bot/utils/test_services.py
index 5e0855704..1b48f6560 100644
--- a/tests/bot/utils/test_services.py
+++ b/tests/bot/utils/test_services.py
@@ -5,11 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp import ClientConnectorError
from bot.utils.services import FAILED_REQUEST_ATTEMPTS, send_to_paste_service
+from tests.helpers import MockBot
class PasteTests(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
- self.http_session = MagicMock()
+ patcher = patch("bot.instance", new=MockBot())
+ self.bot = patcher.start()
+ self.addCleanup(patcher.stop)
@patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}")
async def test_url_and_sent_contents(self):
@@ -17,10 +20,10 @@ class PasteTests(unittest.IsolatedAsyncioTestCase):
response = MagicMock(
json=AsyncMock(return_value={"key": ""})
)
- self.http_session.post().__aenter__.return_value = response
- self.http_session.post.reset_mock()
- await send_to_paste_service(self.http_session, "Content")
- self.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content")
+ self.bot.http_session.post.return_value.__aenter__.return_value = response
+ self.bot.http_session.post.reset_mock()
+ await send_to_paste_service("Content")
+ self.bot.http_session.post.assert_called_once_with("https://paste_service.com/documents", data="Content")
@patch("bot.utils.services.URLs.paste_service", "https://paste_service.com/{key}")
async def test_paste_returns_correct_url_on_success(self):
@@ -34,41 +37,41 @@ class PasteTests(unittest.IsolatedAsyncioTestCase):
response = MagicMock(
json=AsyncMock(return_value={"key": key})
)
- self.http_session.post().__aenter__.return_value = response
+ self.bot.http_session.post.return_value.__aenter__.return_value = response
for expected_output, extension in test_cases:
with self.subTest(msg=f"Send contents with extension {repr(extension)}"):
self.assertEqual(
- await send_to_paste_service(self.http_session, "", extension=extension),
+ await send_to_paste_service("", extension=extension),
expected_output
)
async def test_request_repeated_on_json_errors(self):
"""Json with error message and invalid json are handled as errors and requests repeated."""
test_cases = ({"message": "error"}, {"unexpected_key": None}, {})
- self.http_session.post().__aenter__.return_value = response = MagicMock()
- self.http_session.post.reset_mock()
+ self.bot.http_session.post.return_value.__aenter__.return_value = response = MagicMock()
+ self.bot.http_session.post.reset_mock()
for error_json in test_cases:
with self.subTest(error_json=error_json):
response.json = AsyncMock(return_value=error_json)
- result = await send_to_paste_service(self.http_session, "")
- self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)
+ result = await send_to_paste_service("")
+ self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)
self.assertIsNone(result)
- self.http_session.post.reset_mock()
+ self.bot.http_session.post.reset_mock()
async def test_request_repeated_on_connection_errors(self):
"""Requests are repeated in the case of connection errors."""
- self.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock()))
- result = await send_to_paste_service(self.http_session, "")
- self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)
+ self.bot.http_session.post = MagicMock(side_effect=ClientConnectorError(Mock(), Mock()))
+ result = await send_to_paste_service("")
+ self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)
self.assertIsNone(result)
async def test_general_error_handled_and_request_repeated(self):
"""All `Exception`s are handled, logged and request repeated."""
- self.http_session.post = MagicMock(side_effect=Exception)
- result = await send_to_paste_service(self.http_session, "")
- self.assertEqual(self.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)
+ self.bot.http_session.post = MagicMock(side_effect=Exception)
+ result = await send_to_paste_service("")
+ self.assertEqual(self.bot.http_session.post.call_count, FAILED_REQUEST_ATTEMPTS)
self.assertLogs("bot.utils", logging.ERROR)
self.assertIsNone(result)