aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar MarkKoz <[email protected]>2020-02-28 09:40:55 -0800
committerGravatar MarkKoz <[email protected]>2020-02-28 09:40:55 -0800
commitb86fb16643d0026e8681d2163bd6ce86b092c57a (patch)
treeb37065e33c88b42f5ce73dde98044777728b65f4
parentIgnore NotFound errors inside continue_eval (diff)
parentMerge pull request #757 from python-discord/feat/backend/b131/error-handling (diff)
Merge remote-tracking branch 'origin/master' into eval-enhancements
-rw-r--r--Pipfile4
-rw-r--r--Pipfile.lock121
-rw-r--r--azure-pipelines.yml2
-rw-r--r--bot/__init__.py87
-rw-r--r--bot/__main__.py22
-rw-r--r--bot/api.py119
-rw-r--r--bot/bot.py56
-rw-r--r--bot/cogs/antimalware.py10
-rw-r--r--bot/cogs/antispam.py2
-rw-r--r--bot/cogs/bot.py5
-rw-r--r--bot/cogs/clean.py2
-rw-r--r--bot/cogs/config_verifier.py40
-rw-r--r--bot/cogs/defcon.py20
-rw-r--r--bot/cogs/doc.py2
-rw-r--r--bot/cogs/duck_pond.py2
-rw-r--r--bot/cogs/error_handler.py274
-rw-r--r--bot/cogs/eval.py4
-rw-r--r--bot/cogs/extensions.py2
-rw-r--r--bot/cogs/free.py2
-rw-r--r--bot/cogs/help.py2
-rw-r--r--bot/cogs/information.py115
-rw-r--r--bot/cogs/jams.py6
-rw-r--r--bot/cogs/logging.py4
-rw-r--r--bot/cogs/moderation/infractions.py2
-rw-r--r--bot/cogs/moderation/management.py4
-rw-r--r--bot/cogs/moderation/modlog.py22
-rw-r--r--bot/cogs/moderation/scheduler.py22
-rw-r--r--bot/cogs/moderation/superstarify.py3
-rw-r--r--bot/cogs/off_topic_names.py2
-rw-r--r--bot/cogs/reddit.py13
-rw-r--r--bot/cogs/reminders.py119
-rw-r--r--bot/cogs/snekbox.py4
-rw-r--r--bot/cogs/sync/cog.py100
-rw-r--r--bot/cogs/sync/syncers.py551
-rw-r--r--bot/cogs/tags.py11
-rw-r--r--bot/cogs/utils.py2
-rw-r--r--bot/cogs/verification.py61
-rw-r--r--bot/cogs/watchchannels/watchchannel.py2
-rw-r--r--bot/constants.py66
-rw-r--r--bot/pagination.py42
-rw-r--r--bot/rules/attachments.py2
-rw-r--r--bot/utils/__init__.py12
-rw-r--r--bot/utils/time.py36
-rw-r--r--config-default.yml222
-rw-r--r--docker-compose.yml2
-rw-r--r--tests/base.py34
-rw-r--r--tests/bot/cogs/sync/test_base.py412
-rw-r--r--tests/bot/cogs/sync/test_cog.py395
-rw-r--r--tests/bot/cogs/sync/test_roles.py287
-rw-r--r--tests/bot/cogs/sync/test_users.py241
-rw-r--r--tests/bot/cogs/test_duck_pond.py4
-rw-r--r--tests/bot/cogs/test_information.py24
-rw-r--r--tests/bot/rules/__init__.py76
-rw-r--r--tests/bot/rules/test_attachments.py97
-rw-r--r--tests/bot/rules/test_burst.py56
-rw-r--r--tests/bot/rules/test_burst_shared.py59
-rw-r--r--tests/bot/rules/test_chars.py66
-rw-r--r--tests/bot/rules/test_discord_emojis.py54
-rw-r--r--tests/bot/rules/test_duplicates.py66
-rw-r--r--tests/bot/rules/test_links.py84
-rw-r--r--tests/bot/rules/test_mentions.py90
-rw-r--r--tests/bot/rules/test_newlines.py105
-rw-r--r--tests/bot/rules/test_role_mentions.py57
-rw-r--r--tests/bot/test_api.py64
-rw-r--r--tests/bot/test_utils.py15
-rw-r--r--tests/helpers.py29
-rw-r--r--tox.ini6
67 files changed, 3084 insertions, 1440 deletions
diff --git a/Pipfile b/Pipfile
index 7fd3efae8..400e64c18 100644
--- a/Pipfile
+++ b/Pipfile
@@ -6,7 +6,6 @@ name = "pypi"
[packages]
discord-py = "~=1.3.1"
aiodns = "~=2.0"
-logmatic-python = "~=0.1"
aiohttp = "~=3.5"
sphinx = "~=2.2"
markdownify = "~=0.4"
@@ -19,11 +18,12 @@ deepdiff = "~=4.0"
requests = "~=2.22"
more_itertools = "~=7.2"
urllib3 = ">=1.24.2,<1.25"
+sentry-sdk = "~=0.14"
[dev-packages]
coverage = "~=4.5"
flake8 = "~=3.7"
-flake8-annotations = "~=1.1"
+flake8-annotations = "~=2.0"
flake8-bugbear = "~=19.8"
flake8-docstrings = "~=1.4"
flake8-import-order = "~=0.18"
diff --git a/Pipfile.lock b/Pipfile.lock
index bf8ff47e9..fa29bf995 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
- "sha256": "0a0354a8cbd25b19c61b68f928493a445e737dc6447c97f4c4b52fbf72d887ac"
+ "sha256": "c7706a61eb96c06d073898018ea2dbcf5bd3b15d007496e2d60120a65647f31e"
},
"pipfile-spec": 6,
"requires": {
@@ -18,11 +18,11 @@
"default": {
"aio-pika": {
"hashes": [
- "sha256:a5837277e53755078db3a9e8c45bbca605c8ba9ecba7a02d74a7a1779f444723",
- "sha256:fa32e33b4b7d0804dcf439ae6ff24d2f0a83d1ba280ee9f555e647d71d394ff5"
+ "sha256:4199122a450dffd8303b7857a9d82657bf1487fe329e489520833b40fbe92406",
+ "sha256:fe85c7456e5c060bce4eb9cffab5b2c4d3c563cb72177977b3556c54c8e3aeb6"
],
"index": "pypi",
- "version": "==6.4.1"
+ "version": "==6.5.2"
},
"aiodns": {
"hashes": [
@@ -52,10 +52,10 @@
},
"aiormq": {
"hashes": [
- "sha256:8c215a970133ab5ee7c478decac55b209af7731050f52d11439fe910fa0f9e9d",
- "sha256:9210f3389200aee7d8067f6435f4a9eff2d3a30b88beb5eaae406ccc11c0fc01"
+ "sha256:286e0b0772075580466e45f98f051b9728a9316b9c36f0c14c7bc1409be375b0",
+ "sha256:7ed7d6df6b57af7f8bce7d1ebcbdfc32b676192e46703e81e9e217316e56b5bd"
],
- "version": "==3.2.0"
+ "version": "==3.2.1"
},
"alabaster": {
"hashes": [
@@ -164,18 +164,18 @@
},
"fuzzywuzzy": {
"hashes": [
- "sha256:5ac7c0b3f4658d2743aa17da53a55598144edbc5bee3c6863840636e6926f254",
- "sha256:6f49de47db00e1c71d40ad16da42284ac357936fa9b66bea1df63fed07122d62"
+ "sha256:45016e92264780e58972dca1b3d939ac864b78437422beecebb3095f8efd00e8",
+ "sha256:928244b28db720d1e0ee7587acf660ea49d7e4c632569cad4f1cd7e68a5f0993"
],
"index": "pypi",
- "version": "==0.17.0"
+ "version": "==0.18.0"
},
"idna": {
"hashes": [
- "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
- "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
+ "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb",
+ "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"
],
- "version": "==2.8"
+ "version": "==2.9"
},
"imagesize": {
"hashes": [
@@ -191,13 +191,6 @@
],
"version": "==2.11.1"
},
- "logmatic-python": {
- "hashes": [
- "sha256:0c15ac9f5faa6a60059b28910db642c3dc7722948c3cc940923f8c9039604342"
- ],
- "index": "pypi",
- "version": "==0.1.7"
- },
"lxml": {
"hashes": [
"sha256:06d4e0bbb1d62e38ae6118406d7cdb4693a3fa34ee3762238bcb96c9e36a93cd",
@@ -388,12 +381,6 @@
"index": "pypi",
"version": "==2.8.1"
},
- "python-json-logger": {
- "hashes": [
- "sha256:b7a31162f2a01965a5efb94453ce69230ed208468b0bbc7fdfc56e6d8df2e281"
- ],
- "version": "==0.1.11"
- },
"pytz": {
"hashes": [
"sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d",
@@ -420,11 +407,19 @@
},
"requests": {
"hashes": [
- "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
- "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31"
+ "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee",
+ "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6"
+ ],
+ "index": "pypi",
+ "version": "==2.23.0"
+ },
+ "sentry-sdk": {
+ "hashes": [
+ "sha256:b06dd27391fd11fb32f84fe054e6a64736c469514a718a99fb5ce1dff95d6b28",
+ "sha256:e023da07cfbead3868e1e2ba994160517885a32dfd994fc455b118e37989479b"
],
"index": "pypi",
- "version": "==2.22.0"
+ "version": "==0.14.1"
},
"six": {
"hashes": [
@@ -449,11 +444,11 @@
},
"sphinx": {
"hashes": [
- "sha256:298537cb3234578b2d954ff18c5608468229e116a9757af3b831c2b2b4819159",
- "sha256:e6e766b74f85f37a5f3e0773a1e1be8db3fcb799deb58ca6d18b70b0b44542a5"
+ "sha256:525527074f2e0c2585f68f73c99b4dc257c34bbe308b27f5f8c7a6e20642742f",
+ "sha256:543d39db5f82d83a5c1aa0c10c88f2b6cff2da3e711aa849b2c627b4b403bbd9"
],
"index": "pypi",
- "version": "==2.3.1"
+ "version": "==2.4.2"
},
"sphinxcontrib-applehelp": {
"hashes": [
@@ -556,6 +551,13 @@
}
},
"develop": {
+ "appdirs": {
+ "hashes": [
+ "sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92",
+ "sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e"
+ ],
+ "version": "==1.4.3"
+ },
"aspy.yaml": {
"hashes": [
"sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc",
@@ -579,10 +581,10 @@
},
"cfgv": {
"hashes": [
- "sha256:edb387943b665bf9c434f717bf630fa78aecd53d5900d2e05da6ad6048553144",
- "sha256:fbd93c9ab0a523bf7daec408f3be2ed99a980e20b2d19b50fc184ca6b820d289"
+ "sha256:04b093b14ddf9fd4d17c53ebfd55582d27b76ed30050193c14e560770c5360eb",
+ "sha256:f22b426ed59cd2ab2b54ff96608d846c33dfb8766a67f0b4a6ce130ce244414f"
],
- "version": "==2.0.1"
+ "version": "==3.0.0"
},
"chardet": {
"hashes": [
@@ -636,6 +638,12 @@
"index": "pypi",
"version": "==4.5.4"
},
+ "distlib": {
+ "hashes": [
+ "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21"
+ ],
+ "version": "==0.3.0"
+ },
"dodgy": {
"hashes": [
"sha256:28323cbfc9352139fdd3d316fa17f325cc0e9ac74438cbba51d70f9b48f86c3a",
@@ -658,6 +666,13 @@
],
"version": "==0.3"
},
+ "filelock": {
+ "hashes": [
+ "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59",
+ "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"
+ ],
+ "version": "==3.0.12"
+ },
"flake8": {
"hashes": [
"sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb",
@@ -668,11 +683,11 @@
},
"flake8-annotations": {
"hashes": [
- "sha256:05b85538014c850a86dce7374bb6621c64481c24e35e8e90af1315f4d7a3dbaa",
- "sha256:43e5233a76fda002b91a54a7cc4510f099c4bfd6279502ec70164016250eebd1"
+ "sha256:19a6637a5da1bb7ea7948483ca9e2b9e15b213e687e7bf5ff8c1bfc91c185006",
+ "sha256:bb033b72cdd3a2b0a530bbdf2081f12fbea7d70baeaaebb5899723a45f424b8e"
],
"index": "pypi",
- "version": "==1.1.3"
+ "version": "==2.0.0"
},
"flake8-bugbear": {
"hashes": [
@@ -700,11 +715,11 @@
},
"flake8-string-format": {
"hashes": [
- "sha256:68ea72a1a5b75e7018cae44d14f32473c798cf73d75cbaed86c6a9a907b770b2",
- "sha256:774d56103d9242ed968897455ef49b7d6de272000cfa83de5814273a868832f1"
+ "sha256:65f3da786a1461ef77fca3780b314edb2853c377f2e35069723348c8917deaa2",
+ "sha256:812ff431f10576a74c89be4e85b8e075a705be39bc40c4b4278b5b13e2afa9af"
],
"index": "pypi",
- "version": "==0.2.3"
+ "version": "==0.3.0"
},
"flake8-tidy-imports": {
"hashes": [
@@ -730,10 +745,10 @@
},
"idna": {
"hashes": [
- "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
- "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
+ "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb",
+ "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"
],
- "version": "==2.8"
+ "version": "==2.9"
},
"importlib-metadata": {
"hashes": [
@@ -818,11 +833,11 @@
},
"requests": {
"hashes": [
- "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
- "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31"
+ "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee",
+ "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6"
],
"index": "pypi",
- "version": "==2.22.0"
+ "version": "==2.23.0"
},
"safety": {
"hashes": [
@@ -898,17 +913,17 @@
},
"virtualenv": {
"hashes": [
- "sha256:0d62c70883c0342d59c11d0ddac0d954d0431321a41ab20851facf2b222598f3",
- "sha256:55059a7a676e4e19498f1aad09b8313a38fcc0cdbe4fdddc0e9b06946d21b4bb"
+ "sha256:08f3623597ce73b85d6854fb26608a6f39ee9d055c81178dc6583803797f8994",
+ "sha256:de2cbdd5926c48d7b84e0300dea9e8f276f61d186e8e49223d71d91250fbaebd"
],
- "version": "==16.7.9"
+ "version": "==20.0.4"
},
"zipp": {
"hashes": [
- "sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28",
- "sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98"
+ "sha256:12248a63bbdf7548f89cb4c7cda4681e537031eda29c02ea29674bc6854460c2",
+ "sha256:7c0f8e91abc0dc07a5068f315c52cb30c66bfbc581e5b50704c8a2f6ebae794a"
],
- "version": "==2.1.0"
+ "version": "==3.0.0"
}
}
}
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 0400ac4d2..874364a6f 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -30,7 +30,7 @@ jobs:
- script: python -m flake8
displayName: 'Run linter'
- - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
+ - script: BOT_API_KEY=foo BOT_SENTRY_DSN=blah BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
displayName: Run tests
- script: coverage report -m && coverage xml -o coverage.xml
diff --git a/bot/__init__.py b/bot/__init__.py
index 789ace5c0..f7a410706 100644
--- a/bot/__init__.py
+++ b/bot/__init__.py
@@ -4,11 +4,8 @@ import sys
from logging import Logger, StreamHandler, handlers
from pathlib import Path
-from logmatic import JsonFormatter
-
-
-logging.TRACE = 5
-logging.addLevelName(logging.TRACE, "TRACE")
+TRACE_LEVEL = logging.TRACE = 5
+logging.addLevelName(TRACE_LEVEL, "TRACE")
def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
@@ -20,75 +17,29 @@ def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
"""
- if self.isEnabledFor(logging.TRACE):
- self._log(logging.TRACE, msg, args, **kwargs)
+ if self.isEnabledFor(TRACE_LEVEL):
+ self._log(TRACE_LEVEL, msg, args, **kwargs)
Logger.trace = monkeypatch_trace
-# Set up logging
-logging_handlers = []
-
-# We can't import this yet, so we have to define it ourselves
-DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False
-
-LOG_DIR = Path("logs")
-LOG_DIR.mkdir(exist_ok=True)
-
-if DEBUG_MODE:
- logging_handlers.append(StreamHandler(stream=sys.stdout))
-
- json_handler = logging.FileHandler(filename=Path(LOG_DIR, "log.json"), mode="w")
- json_handler.formatter = JsonFormatter()
- logging_handlers.append(json_handler)
-else:
-
- logfile = Path(LOG_DIR, "bot.log")
- megabyte = 1048576
-
- filehandler = handlers.RotatingFileHandler(logfile, maxBytes=(megabyte*5), backupCount=7)
- logging_handlers.append(filehandler)
-
- json_handler = logging.StreamHandler(stream=sys.stdout)
- json_handler.formatter = JsonFormatter()
- logging_handlers.append(json_handler)
-
-
-logging.basicConfig(
- format="%(asctime)s Bot: | %(name)33s | %(levelname)8s | %(message)s",
- datefmt="%b %d %H:%M:%S",
- level=logging.TRACE if DEBUG_MODE else logging.INFO,
- handlers=logging_handlers
-)
-
-log = logging.getLogger(__name__)
-
-
-for key, value in logging.Logger.manager.loggerDict.items():
- # Force all existing loggers to the correct level and handlers
- # This happens long before we instantiate our loggers, so
- # those should still have the expected level
-
- if key == "bot":
- continue
-
- if not isinstance(value, logging.Logger):
- # There might be some logging.PlaceHolder objects in there
- continue
+DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local")
- if DEBUG_MODE:
- value.setLevel(logging.DEBUG)
- else:
- value.setLevel(logging.INFO)
+log_format = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s")
- for handler in value.handlers.copy():
- value.removeHandler(handler)
+stream_handler = StreamHandler(stream=sys.stdout)
+stream_handler.setFormatter(log_format)
- for handler in logging_handlers:
- value.addHandler(handler)
+log_file = Path("logs", "bot.log")
+log_file.parent.mkdir(exist_ok=True)
+file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7)
+file_handler.setFormatter(log_format)
+root_log = logging.getLogger()
+root_log.setLevel(TRACE_LEVEL if DEBUG_MODE else logging.INFO)
+root_log.addHandler(stream_handler)
+root_log.addHandler(file_handler)
-# Silence irrelevant loggers
-logging.getLogger("aio_pika").setLevel(logging.ERROR)
-logging.getLogger("discord").setLevel(logging.ERROR)
-logging.getLogger("websockets").setLevel(logging.ERROR)
+logging.getLogger("discord").setLevel(logging.WARNING)
+logging.getLogger("websockets").setLevel(logging.WARNING)
+logging.getLogger(__name__)
diff --git a/bot/__main__.py b/bot/__main__.py
index 84bc7094b..3df477a6d 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -1,10 +1,23 @@
+import logging
+
import discord
+import sentry_sdk
from discord.ext.commands import when_mentioned_or
+from sentry_sdk.integrations.logging import LoggingIntegration
from bot import patches
from bot.bot import Bot
-from bot.constants import Bot as BotConfig, DEBUG_MODE
+from bot.constants import Bot as BotConfig
+
+sentry_logging = LoggingIntegration(
+ level=logging.DEBUG,
+ event_level=logging.WARNING
+)
+sentry_sdk.init(
+ dsn=BotConfig.sentry_dsn,
+ integrations=[sentry_logging]
+)
bot = Bot(
command_prefix=when_mentioned_or(BotConfig.prefix),
@@ -18,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler")
bot.load_extension("bot.cogs.filtering")
bot.load_extension("bot.cogs.logging")
bot.load_extension("bot.cogs.security")
+bot.load_extension("bot.cogs.config_verifier")
# Commands, etc
bot.load_extension("bot.cogs.antimalware")
@@ -27,10 +41,8 @@ bot.load_extension("bot.cogs.clean")
bot.load_extension("bot.cogs.extensions")
bot.load_extension("bot.cogs.help")
-# Only load this in production
-if not DEBUG_MODE:
- bot.load_extension("bot.cogs.doc")
- bot.load_extension("bot.cogs.verification")
+bot.load_extension("bot.cogs.doc")
+bot.load_extension("bot.cogs.verification")
# Feature cogs
bot.load_extension("bot.cogs.alias")
diff --git a/bot/api.py b/bot/api.py
index 56db99828..e59916114 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -32,6 +32,11 @@ class ResponseCodeError(ValueError):
class APIClient:
"""Django Site API wrapper."""
+ # These are class attributes so they can be seen when being mocked for tests.
+ # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details.
+ session: Optional[aiohttp.ClientSession] = None
+ loop: asyncio.AbstractEventLoop = None
+
def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):
auth_headers = {
'Authorization': f"Token {Keys.site_api}"
@@ -42,7 +47,7 @@ class APIClient:
else:
kwargs['headers'] = auth_headers
- self.session: Optional[aiohttp.ClientSession] = None
+ self.session = None
self.loop = loop
self._ready = asyncio.Event(loop=loop)
@@ -85,43 +90,35 @@ class APIClient:
response_text = await response.text()
raise ResponseCodeError(response=response, response_text=response_text)
- async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
- """Site API GET."""
+ async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Send an HTTP request to the site API and return the JSON response."""
await self._ready.wait()
- async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp:
+ async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp:
await self.maybe_raise_for_status(resp, raise_for_status)
return await resp.json()
- async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
- """Site API PATCH."""
- await self._ready.wait()
+ async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Site API GET."""
+ return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs)
- async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp:
- await self.maybe_raise_for_status(resp, raise_for_status)
- return await resp.json()
+ async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
+ """Site API PATCH."""
+ return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs)
- async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
+ async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API POST."""
- await self._ready.wait()
+ return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs)
- async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp:
- await self.maybe_raise_for_status(resp, raise_for_status)
- return await resp.json()
-
- async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict:
+ async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict:
"""Site API PUT."""
- await self._ready.wait()
+ return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs)
- async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp:
- await self.maybe_raise_for_status(resp, raise_for_status)
- return await resp.json()
-
- async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]:
+ async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]:
"""Site API DELETE."""
await self._ready.wait()
- async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp:
+ async with self.session.delete(self._url_for(endpoint), **kwargs) as resp:
if resp.status == 204:
return None
@@ -141,77 +138,3 @@ def loop_is_running() -> bool:
except RuntimeError:
return False
return True
-
-
-class APILoggingHandler(logging.StreamHandler):
- """Site API logging handler."""
-
- def __init__(self, client: APIClient):
- logging.StreamHandler.__init__(self)
- self.client = client
-
- # internal batch of shipoff tasks that must not be scheduled
- # on the event loop yet - scheduled when the event loop is ready.
- self.queue = []
-
- async def ship_off(self, payload: dict) -> None:
- """Ship log payload to the logging API."""
- try:
- await self.client.post('logs', json=payload)
- except ResponseCodeError as err:
- log.warning(
- "Cannot send logging record to the site, got code %d.",
- err.response.status,
- extra={'via_handler': True}
- )
- except Exception as err:
- log.warning(
- "Cannot send logging record to the site: %r",
- err,
- extra={'via_handler': True}
- )
-
- def emit(self, record: logging.LogRecord) -> None:
- """
- Determine if a log record should be shipped to the logging API.
-
- If the asyncio event loop is not yet running, log records will instead be put in a queue
- which will be consumed once the event loop is running.
-
- The following two conditions are set:
- 1. Do not log anything below DEBUG (only applies to the monkeypatched `TRACE` level)
- 2. Ignore log records originating from this logging handler itself to prevent infinite recursion
- """
- if (
- record.levelno >= logging.DEBUG
- and not record.__dict__.get('via_handler')
- ):
- payload = {
- 'application': 'bot',
- 'logger_name': record.name,
- 'level': record.levelname.lower(),
- 'module': record.module,
- 'line': record.lineno,
- 'message': self.format(record)
- }
-
- task = self.ship_off(payload)
- if not loop_is_running():
- self.queue.append(task)
- else:
- asyncio.create_task(task)
- self.schedule_queued_tasks()
-
- def schedule_queued_tasks(self) -> None:
- """Consume the queue and schedule the logging of each queued record."""
- for task in self.queue:
- asyncio.create_task(task)
-
- if self.queue:
- log.debug(
- "Scheduled %d pending logging tasks.",
- len(self.queue),
- extra={'via_handler': True}
- )
-
- self.queue.clear()
diff --git a/bot/bot.py b/bot/bot.py
index 8f808272f..19b9035c4 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,11 +1,14 @@
+import asyncio
import logging
import socket
from typing import Optional
import aiohttp
+import discord
from discord.ext import commands
from bot import api
+from bot import constants
log = logging.getLogger('bot')
@@ -17,17 +20,17 @@ class Bot(commands.Bot):
# Use asyncio for DNS resolution instead of threads so threads aren't spammed.
# Use AF_INET as its socket family to prevent HTTPS related problems both locally
# and in production.
- self.connector = aiohttp.TCPConnector(
+ self._connector = aiohttp.TCPConnector(
resolver=aiohttp.AsyncResolver(),
family=socket.AF_INET,
)
- super().__init__(*args, connector=self.connector, **kwargs)
+ super().__init__(*args, connector=self._connector, **kwargs)
- self.http_session: Optional[aiohttp.ClientSession] = None
- self.api_client = api.APIClient(loop=self.loop, connector=self.connector)
+ self._guild_available = asyncio.Event()
- log.addHandler(api.APILoggingHandler(self.api_client))
+ self.http_session: Optional[aiohttp.ClientSession] = None
+ self.api_client = api.APIClient(loop=self.loop, connector=self._connector)
def add_cog(self, cog: commands.Cog) -> None:
"""Adds a "cog" to the bot and logs the operation."""
@@ -48,6 +51,47 @@ class Bot(commands.Bot):
async def start(self, *args, **kwargs) -> None:
"""Open an aiohttp session before logging in and connecting to Discord."""
- self.http_session = aiohttp.ClientSession(connector=self.connector)
+ self.http_session = aiohttp.ClientSession(connector=self._connector)
await super().start(*args, **kwargs)
+
+ async def on_guild_available(self, guild: discord.Guild) -> None:
+ """
+ Set the internal guild available event when constants.Guild.id becomes available.
+
+ If the cache appears to still be empty (no members, no channels, or no roles), the event
+ will not be set.
+ """
+ if guild.id != constants.Guild.id:
+ return
+
+ if not guild.roles or not guild.members or not guild.channels:
+ msg = "Guild available event was dispatched but the cache appears to still be empty!"
+ log.warning(msg)
+
+ try:
+ webhook = await self.fetch_webhook(constants.Webhooks.dev_log)
+ except discord.HTTPException as e:
+ log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}")
+ else:
+ await webhook.send(f"<@&{constants.Roles.admin}> {msg}")
+
+ return
+
+ self._guild_available.set()
+
+ async def on_guild_unavailable(self, guild: discord.Guild) -> None:
+ """Clear the internal guild available event when constants.Guild.id becomes unavailable."""
+ if guild.id != constants.Guild.id:
+ return
+
+ self._guild_available.clear()
+
+ async def wait_until_guild_available(self) -> None:
+ """
+ Wait until the constants.Guild.id guild is available (and the cache is ready).
+
+ The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE
+ gateway event before giving up and thus not populating the cache for unavailable guilds.
+ """
+ await self._guild_available.wait()
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 28e3e5d96..9e9e81364 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -4,7 +4,7 @@ from discord import Embed, Message, NotFound
from discord.ext.commands import Cog
from bot.bot import Bot
-from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs
log = logging.getLogger(__name__)
@@ -18,7 +18,13 @@ class AntiMalware(Cog):
@Cog.listener()
async def on_message(self, message: Message) -> None:
"""Identify messages with prohibited attachments."""
- if not message.attachments:
+ # Return when message don't have attachment and don't moderate DMs
+ if not message.attachments or not message.guild:
+ return
+
+ # Check if user is staff, if is, return
+ # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
+ if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles):
return
embed = Embed()
diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py
index f67ef6f05..baa6b9459 100644
--- a/bot/cogs/antispam.py
+++ b/bot/cogs/antispam.py
@@ -123,7 +123,7 @@ class AntiSpam(Cog):
async def alert_on_validation_error(self) -> None:
"""Unloads the cog and alerts admins if configuration validation failed."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if self.validation_errors:
body = "**The following errors were encountered:**\n"
body += "\n".join(f"- {error}" for error in self.validation_errors.values())
diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py
index 73b1e8f41..f17135877 100644
--- a/bot/cogs/bot.py
+++ b/bot/cogs/bot.py
@@ -34,13 +34,12 @@ class BotCog(Cog, name="Bot"):
Channels.help_5: 0,
Channels.help_6: 0,
Channels.help_7: 0,
- Channels.python: 0,
+ Channels.python_discussion: 0,
}
# These channels will also work, but will not be subject to cooldown
self.channel_whitelist = (
- Channels.bot,
- Channels.devtest,
+ Channels.bot_commands,
)
# Stores improperly formatted Python codeblock message ids and the corresponding bot message
diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py
index 2104efe57..5cdf0b048 100644
--- a/bot/cogs/clean.py
+++ b/bot/cogs/clean.py
@@ -173,7 +173,7 @@ class Clean(Cog):
colour=Colour(Colours.soft_red),
title="Bulk message delete",
text=message,
- channel_id=Channels.modlog,
+ channel_id=Channels.mod_log,
)
@group(invoke_without_command=True, name="clean", aliases=["purge"])
diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py
new file mode 100644
index 000000000..d72c6c22e
--- /dev/null
+++ b/bot/cogs/config_verifier.py
@@ -0,0 +1,40 @@
+import logging
+
+from discord.ext.commands import Cog
+
+from bot import constants
+from bot.bot import Bot
+
+
+log = logging.getLogger(__name__)
+
+
+class ConfigVerifier(Cog):
+ """Verify config on startup."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self.channel_verify_task = self.bot.loop.create_task(self.verify_channels())
+
+ async def verify_channels(self) -> None:
+ """
+ Verify channels.
+
+ If any channels in config aren't present in server, log them in a warning.
+ """
+ await self.bot.wait_until_guild_available()
+ server = self.bot.get_guild(constants.Guild.id)
+
+ server_channel_ids = {channel.id for channel in server.channels}
+ invalid_channels = [
+ channel_name for channel_name, channel_id in constants.Channels
+ if channel_id not in server_channel_ids
+ ]
+
+ if invalid_channels:
+ log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.")
+
+
+def setup(bot: Bot) -> None:
+ """Load the ConfigVerifier cog."""
+ bot.add_cog(ConfigVerifier(bot))
diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py
index 3e7350fcc..cc0f79fe8 100644
--- a/bot/cogs/defcon.py
+++ b/bot/cogs/defcon.py
@@ -59,7 +59,7 @@ class Defcon(Cog):
async def sync_settings(self) -> None:
"""On cog load, try to synchronize DEFCON settings to the API."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
self.channel = await self.bot.fetch_channel(Channels.defcon)
try:
@@ -68,20 +68,20 @@ class Defcon(Cog):
except Exception: # Yikes!
log.exception("Unable to get DEFCON settings!")
- await self.bot.get_channel(Channels.devlog).send(
- f"<@&{Roles.admin}> **WARNING**: Unable to get DEFCON settings!"
+ await self.bot.get_channel(Channels.dev_log).send(
+ f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!"
)
else:
if data["enabled"]:
self.enabled = True
self.days = timedelta(days=data["days"])
- log.warning(f"DEFCON enabled: {self.days.days} days")
+ log.info(f"DEFCON enabled: {self.days.days} days")
else:
self.enabled = False
self.days = timedelta(days=0)
- log.warning(f"DEFCON disabled")
+ log.info(f"DEFCON disabled")
await self.update_channel_topic()
@@ -118,7 +118,7 @@ class Defcon(Cog):
)
@group(name='defcon', aliases=('dc',), invoke_without_command=True)
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def defcon_group(self, ctx: Context) -> None:
"""Check the DEFCON status or run a subcommand."""
await ctx.invoke(self.bot.get_command("help"), "defcon")
@@ -146,7 +146,7 @@ class Defcon(Cog):
await self.send_defcon_log(action, ctx.author, error)
@defcon_group.command(name='enable', aliases=('on', 'e'))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def enable_command(self, ctx: Context) -> None:
"""
Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!
@@ -159,7 +159,7 @@ class Defcon(Cog):
await self.update_channel_topic()
@defcon_group.command(name='disable', aliases=('off', 'd'))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def disable_command(self, ctx: Context) -> None:
"""Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!"""
self.enabled = False
@@ -167,7 +167,7 @@ class Defcon(Cog):
await self.update_channel_topic()
@defcon_group.command(name='status', aliases=('s',))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def status_command(self, ctx: Context) -> None:
"""Check the current status of DEFCON mode."""
embed = Embed(
@@ -179,7 +179,7 @@ class Defcon(Cog):
await ctx.send(embed=embed)
@defcon_group.command(name='days')
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def days_command(self, ctx: Context, days: int) -> None:
"""Set how old an account must be to join the server, in days, with DEFCON mode enabled."""
self.days = timedelta(days=days)
diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py
index 6e7c00b6a..204cffb37 100644
--- a/bot/cogs/doc.py
+++ b/bot/cogs/doc.py
@@ -157,7 +157,7 @@ class Doc(commands.Cog):
async def init_refresh_inventory(self) -> None:
"""Refresh documentation inventory on cog initialization."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
await self.refresh_inventory()
async def update_single(
diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py
index 345d2856c..1f84a0609 100644
--- a/bot/cogs/duck_pond.py
+++ b/bot/cogs/duck_pond.py
@@ -22,7 +22,7 @@ class DuckPond(Cog):
async def fetch_webhook(self) -> None:
"""Fetches the webhook object, so we can post to it."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
try:
self.webhook = await self.bot.fetch_webhook(self.webhook_id)
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py
index 52893b2ee..864450139 100644
--- a/bot/cogs/error_handler.py
+++ b/bot/cogs/error_handler.py
@@ -1,20 +1,9 @@
import contextlib
import logging
+import typing as t
-from discord.ext.commands import (
- BadArgument,
- BotMissingPermissions,
- CheckFailure,
- CommandError,
- CommandInvokeError,
- CommandNotFound,
- CommandOnCooldown,
- DisabledCommand,
- MissingPermissions,
- NoPrivateMessage,
- UserInputError,
-)
-from discord.ext.commands import Cog, Context
+from discord.ext.commands import Cog, Command, Context, errors
+from sentry_sdk import push_scope
from bot.api import ResponseCodeError
from bot.bot import Bot
@@ -31,126 +20,201 @@ class ErrorHandler(Cog):
self.bot = bot
@Cog.listener()
- async def on_command_error(self, ctx: Context, e: CommandError) -> None:
+ async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None:
"""
Provide generic command error handling.
- Error handling is deferred to any local error handler, if present.
-
- Error handling emits a single error response, prioritized as follows:
- 1. If the name fails to match a command but matches a tag, the tag is invoked
- 2. Send a BadArgument error message to the invoking context & invoke the command's help
- 3. Send a UserInputError error message to the invoking context & invoke the command's help
- 4. Send a NoPrivateMessage error message to the invoking context
- 5. Send a BotMissingPermissions error message to the invoking context
- 6. Log a MissingPermissions error, no message is sent
- 7. Send a InChannelCheckFailure error message to the invoking context
- 8. Log CheckFailure, CommandOnCooldown, and DisabledCommand errors, no message is sent
- 9. For CommandInvokeErrors, response is based on the type of error:
- * 404: Error message is sent to the invoking context
- * 400: Log the resopnse JSON, no message is sent
- * 500 <= status <= 600: Error message is sent to the invoking context
- 10. Otherwise, handling is deferred to `handle_unexpected_error`
+ Error handling is deferred to any local error handler, if present. This is done by
+ checking for the presence of a `handled` attribute on the error.
+
+ Error handling emits a single error message in the invoking context `ctx` and a log message,
+ prioritised as follows:
+
+ 1. If the name fails to match a command but matches a tag, the tag is invoked
+ * If CommandNotFound is raised when invoking the tag (determined by the presence of the
+ `invoked_from_error_handler` attribute), this error is treated as being unexpected
+ and therefore sends an error message
+ * Commands in the verification channel are ignored
+ 2. UserInputError: see `handle_user_input_error`
+ 3. CheckFailure: see `handle_check_failure`
+ 4. CommandOnCooldown: send an error message in the invoking context
+ 5. ResponseCodeError: see `handle_api_error`
+ 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error`
"""
command = ctx.command
- parent = None
+ if hasattr(e, "handled"):
+ log.trace(f"Command {command} had its error already handled locally; ignoring.")
+ return
+
+ # Try to look for a tag with the command's name if the command isn't found.
+ if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
+ if ctx.channel.id != Channels.verification:
+ await self.try_get_tag(ctx)
+ return # Exit early to avoid logging.
+ elif isinstance(e, errors.UserInputError):
+ await self.handle_user_input_error(ctx, e)
+ elif isinstance(e, errors.CheckFailure):
+ await self.handle_check_failure(ctx, e)
+ elif isinstance(e, errors.CommandOnCooldown):
+ await ctx.send(e)
+ elif isinstance(e, errors.CommandInvokeError):
+ if isinstance(e.original, ResponseCodeError):
+ await self.handle_api_error(ctx, e.original)
+ else:
+ await self.handle_unexpected_error(ctx, e.original)
+ return # Exit early to avoid logging.
+ elif not isinstance(e, errors.DisabledCommand):
+ # ConversionError, MaxConcurrencyReached, ExtensionError
+ await self.handle_unexpected_error(ctx, e)
+ return # Exit early to avoid logging.
+
+ log.debug(
+ f"Command {command} invoked by {ctx.message.author} with error "
+ f"{e.__class__.__name__}: {e}"
+ )
+
+ async def get_help_command(self, command: t.Optional[Command]) -> t.Tuple:
+ """Return the help command invocation args to display help for `command`."""
+ parent = None
if command is not None:
parent = command.parent
# Retrieve the help command for the invoked command.
if parent and command:
- help_command = (self.bot.get_command("help"), parent.name, command.name)
+ return self.bot.get_command("help"), parent.name, command.name
elif command:
- help_command = (self.bot.get_command("help"), command.name)
+ return self.bot.get_command("help"), command.name
else:
- help_command = (self.bot.get_command("help"),)
+ return self.bot.get_command("help")
- if hasattr(e, "handled"):
- log.trace(f"Command {command} had its error already handled locally; ignoring.")
+ async def try_get_tag(self, ctx: Context) -> None:
+ """
+ Attempt to display a tag by interpreting the command name as a tag name.
+
+ The invocation of tags get respects its checks. Any CommandErrors raised will be handled
+ by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to
+ the context to prevent infinite recursion in the case of a CommandNotFound exception.
+ """
+ tags_get_command = self.bot.get_command("tags get")
+ ctx.invoked_from_error_handler = True
+
+ log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
+ try:
+ if not await tags_get_command.can_run(ctx):
+ log.debug(log_msg)
+ return
+ except errors.CommandError as tag_error:
+ log.debug(log_msg)
+ await self.on_command_error(ctx, tag_error)
return
- # Try to look for a tag with the command's name if the command isn't found.
- if isinstance(e, CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
- if not ctx.channel.id == Channels.verification:
- tags_get_command = self.bot.get_command("tags get")
- ctx.invoked_from_error_handler = True
-
- log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
- try:
- if not await tags_get_command.can_run(ctx):
- log.debug(log_msg)
- return
- except CommandError as tag_error:
- log.debug(log_msg)
- await self.on_command_error(ctx, tag_error)
- return
-
- # Return to not raise the exception
- with contextlib.suppress(ResponseCodeError):
- await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with)
- return
- elif isinstance(e, BadArgument):
+ # Return to not raise the exception
+ with contextlib.suppress(ResponseCodeError):
+ await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with)
+ return
+
+ async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None:
+ """
+ Send an error message in `ctx` for UserInputError, sometimes invoking the help command too.
+
+ * MissingRequiredArgument: send an error message with arg name and the help command
+ * TooManyArguments: send an error message and the help command
+ * BadArgument: send an error message and the help command
+ * BadUnionArgument: send an error message including the error produced by the last converter
+ * ArgumentParsingError: send an error message
+ * Other: send an error message and the help command
+ """
+ # TODO: use ctx.send_help() once PR #519 is merged.
+ help_command = await self.get_help_command(ctx.command)
+
+ if isinstance(e, errors.MissingRequiredArgument):
+ await ctx.send(f"Missing required argument `{e.param.name}`.")
+ await ctx.invoke(*help_command)
+ elif isinstance(e, errors.TooManyArguments):
+ await ctx.send(f"Too many arguments provided.")
+ await ctx.invoke(*help_command)
+ elif isinstance(e, errors.BadArgument):
await ctx.send(f"Bad argument: {e}\n")
await ctx.invoke(*help_command)
- elif isinstance(e, UserInputError):
+ elif isinstance(e, errors.BadUnionArgument):
+ await ctx.send(f"Bad argument: {e}\n```{e.errors[-1]}```")
+ elif isinstance(e, errors.ArgumentParsingError):
+ await ctx.send(f"Argument parsing error: {e}")
+ else:
await ctx.send("Something about your input seems off. Check the arguments:")
await ctx.invoke(*help_command)
- log.debug(
- f"Command {command} invoked by {ctx.message.author} with error "
- f"{e.__class__.__name__}: {e}"
- )
- elif isinstance(e, NoPrivateMessage):
- await ctx.send("Sorry, this command can't be used in a private message!")
- elif isinstance(e, BotMissingPermissions):
- await ctx.send(f"Sorry, it looks like I don't have the permissions I need to do that.")
- log.warning(
- f"The bot is missing permissions to execute command {command}: {e.missing_perms}"
- )
- elif isinstance(e, MissingPermissions):
- log.debug(
- f"{ctx.message.author} is missing permissions to invoke command {command}: "
- f"{e.missing_perms}"
+
+ @staticmethod
+ async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None:
+ """
+ Send an error message in `ctx` for certain types of CheckFailure.
+
+ The following types are handled:
+
+ * BotMissingPermissions
+ * BotMissingRole
+ * BotMissingAnyRole
+ * NoPrivateMessage
+ * InChannelCheckFailure
+ """
+ bot_missing_errors = (
+ errors.BotMissingPermissions,
+ errors.BotMissingRole,
+ errors.BotMissingAnyRole
+ )
+
+ if isinstance(e, bot_missing_errors):
+ await ctx.send(
+ f"Sorry, it looks like I don't have the permissions or roles I need to do that."
)
- elif isinstance(e, InChannelCheckFailure):
+ elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)):
await ctx.send(e)
- elif isinstance(e, (CheckFailure, CommandOnCooldown, DisabledCommand)):
- log.debug(
- f"Command {command} invoked by {ctx.message.author} with error "
- f"{e.__class__.__name__}: {e}"
- )
- elif isinstance(e, CommandInvokeError):
- if isinstance(e.original, ResponseCodeError):
- status = e.original.response.status
-
- if status == 404:
- await ctx.send("There does not seem to be anything matching your query.")
- elif status == 400:
- content = await e.original.response.json()
- log.debug(f"API responded with 400 for command {command}: %r.", content)
- await ctx.send("According to the API, your request is malformed.")
- elif 500 <= status < 600:
- await ctx.send("Sorry, there seems to be an internal issue with the API.")
- log.warning(f"API responded with {status} for command {command}")
- else:
- await ctx.send(f"Got an unexpected status code from the API (`{status}`).")
- log.warning(f"Unexpected API response for command {command}: {status}")
- else:
- await self.handle_unexpected_error(ctx, e.original)
+
+ @staticmethod
+ async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None:
+ """Send an error message in `ctx` for ResponseCodeError and log it."""
+ if e.status == 404:
+ await ctx.send("There does not seem to be anything matching your query.")
+ log.debug(f"API responded with 404 for command {ctx.command}")
+ elif e.status == 400:
+ content = await e.response.json()
+ log.debug(f"API responded with 400 for command {ctx.command}: %r.", content)
+ await ctx.send("According to the API, your request is malformed.")
+ elif 500 <= e.status < 600:
+ await ctx.send("Sorry, there seems to be an internal issue with the API.")
+ log.warning(f"API responded with {e.status} for command {ctx.command}")
else:
- await self.handle_unexpected_error(ctx, e)
+ await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).")
+ log.warning(f"Unexpected API response for command {ctx.command}: {e.status}")
@staticmethod
- async def handle_unexpected_error(ctx: Context, e: CommandError) -> None:
- """Generic handler for errors without an explicit handler."""
+ async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None:
+ """Send a generic error message in `ctx` and log the exception as an error with exc_info."""
await ctx.send(
f"Sorry, an unexpected error occurred. Please let us know!\n\n"
f"```{e.__class__.__name__}: {e}```"
)
- log.error(
- f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}"
- )
- raise e
+
+ with push_scope() as scope:
+ scope.user = {
+ "id": ctx.author.id,
+ "username": str(ctx.author)
+ }
+
+ scope.set_tag("command", ctx.command.qualified_name)
+ scope.set_tag("message_id", ctx.message.id)
+ scope.set_tag("channel_id", ctx.channel.id)
+
+ scope.set_extra("full_message", ctx.message.content)
+
+ if ctx.guild is not None:
+ scope.set_extra(
+ "jump_to",
+ f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}"
+ )
+
+ log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e)
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py
index 9c729f28a..52136fc8d 100644
--- a/bot/cogs/eval.py
+++ b/bot/cogs/eval.py
@@ -174,14 +174,14 @@ async def func(): # (None,) -> Any
await ctx.send(f"```py\n{out}```", embed=embed)
@group(name='internal', aliases=('int',))
- @with_role(Roles.owner, Roles.admin)
+ @with_role(Roles.owners, Roles.admins)
async def internal_group(self, ctx: Context) -> None:
"""Internal commands. Top secret!"""
if not ctx.invoked_subcommand:
await ctx.invoke(self.bot.get_command("help"), "internal")
@internal_group.command(name='eval', aliases=('e',))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def eval(self, ctx: Context, *, code: str) -> None:
"""Run eval in a REPL-like format."""
code = code.strip("`")
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
index f16e79fb7..b312e1a1d 100644
--- a/bot/cogs/extensions.py
+++ b/bot/cogs/extensions.py
@@ -221,7 +221,7 @@ class Extensions(commands.Cog):
# This cannot be static (must have a __func__ attribute).
def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators and core developers to invoke the commands in this cog."""
- return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer)
+ return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers)
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
diff --git a/bot/cogs/free.py b/bot/cogs/free.py
index 49cab6172..02c02d067 100644
--- a/bot/cogs/free.py
+++ b/bot/cogs/free.py
@@ -22,7 +22,7 @@ class Free(Cog):
PYTHON_HELP_ID = Categories.python_help
@command(name="free", aliases=('f',))
- @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)
+ @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES)
async def free(self, ctx: Context, user: Member = None, seek: int = 2) -> None:
"""
Lists free help channels by likeliness of availability.
diff --git a/bot/cogs/help.py b/bot/cogs/help.py
index fd5bbc3ca..744722220 100644
--- a/bot/cogs/help.py
+++ b/bot/cogs/help.py
@@ -507,7 +507,7 @@ class Help(DiscordCog):
"""Custom Embed Pagination Help feature."""
@commands.command('help')
- @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)
+ @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES)
async def new_help(self, ctx: Context, *commands) -> None:
"""Shows Command Help."""
try:
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index 125d7ce24..49beca15b 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -2,14 +2,12 @@ import colorsys
import logging
import pprint
import textwrap
-import typing
-from collections import defaultdict
-from typing import Any, Mapping, Optional
-
-import discord
-from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils
-from discord.ext import commands
-from discord.ext.commands import BucketType, Cog, Context, command, group
+from collections import Counter, defaultdict
+from string import Template
+from typing import Any, Mapping, Optional, Union
+
+from discord import Colour, Embed, Member, Message, Role, Status, utils
+from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group
from discord.utils import escape_markdown
from bot import constants
@@ -32,8 +30,7 @@ class Information(Cog):
async def roles_info(self, ctx: Context) -> None:
"""Returns a list of all roles and their corresponding IDs."""
# Sort the roles alphabetically and remove the @everyone role
- roles = sorted(ctx.guild.roles, key=lambda role: role.name)
- roles = [role for role in roles if role.name != "@everyone"]
+ roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name)
# Build a string
role_string = ""
@@ -46,20 +43,20 @@ class Information(Cog):
colour=Colour.blurple(),
description=role_string
)
-
embed.set_footer(text=f"Total roles: {len(roles)}")
await ctx.send(embed=embed)
@with_role(*constants.MODERATION_ROLES)
@command(name="role")
- async def role_info(self, ctx: Context, *roles: typing.Union[Role, str]) -> None:
+ async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None:
"""
Return information on a role or list of roles.
To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks.
"""
parsed_roles = []
+ failed_roles = []
for role_name in roles:
if isinstance(role_name, Role):
@@ -70,29 +67,29 @@ class Information(Cog):
role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles)
if not role:
- await ctx.send(f":x: Could not convert `{role_name}` to a role")
+ failed_roles.append(role_name)
continue
parsed_roles.append(role)
+ if failed_roles:
+ await ctx.send(
+ ":x: I could not convert the following role names to a role: \n- "
+ "\n- ".join(failed_roles)
+ )
+
for role in parsed_roles:
+ h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb())
+
embed = Embed(
title=f"{role.name} info",
colour=role.colour,
)
-
embed.add_field(name="ID", value=role.id, inline=True)
-
embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True)
-
- h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb())
-
embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True)
-
embed.add_field(name="Member count", value=len(role.members), inline=True)
-
embed.add_field(name="Position", value=role.position)
-
embed.add_field(name="Permission code", value=role.permissions.value, inline=True)
await ctx.send(embed=embed)
@@ -104,40 +101,23 @@ class Information(Cog):
features = ", ".join(ctx.guild.features)
region = ctx.guild.region
- # How many of each type of channel?
roles = len(ctx.guild.roles)
- channels = ctx.guild.channels
- text_channels = 0
- category_channels = 0
- voice_channels = 0
- for channel in channels:
- if type(channel) == TextChannel:
- text_channels += 1
- elif type(channel) == CategoryChannel:
- category_channels += 1
- elif type(channel) == VoiceChannel:
- voice_channels += 1
-
- # How many of each user status?
member_count = ctx.guild.member_count
- members = ctx.guild.members
- online = 0
- dnd = 0
- idle = 0
- offline = 0
- for member in members:
- if str(member.status) == "online":
- online += 1
- elif str(member.status) == "offline":
- offline += 1
- elif str(member.status) == "idle":
- idle += 1
- elif str(member.status) == "dnd":
- dnd += 1
- embed = Embed(
- colour=Colour.blurple(),
- description=textwrap.dedent(f"""
+ # How many of each type of channel?
+ channels = Counter(c.type for c in ctx.guild.channels)
+ channel_counts = "".join(sorted(f"{str(ch).title()} channels: {channels[ch]}\n" for ch in channels)).strip()
+
+ # How many of each user status?
+ statuses = Counter(member.status for member in ctx.guild.members)
+ embed = Embed(colour=Colour.blurple())
+
+ # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the
+ # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting
+ # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts
+ # after the dedent is made.
+ embed.description = Template(
+ textwrap.dedent(f"""
**Server information**
Created: {created}
Voice region: {region}
@@ -146,18 +126,15 @@ class Information(Cog):
**Counts**
Members: {member_count:,}
Roles: {roles}
- Text: {text_channels}
- Voice: {voice_channels}
- Channel categories: {category_channels}
+ $channel_counts
**Members**
- {constants.Emojis.status_online} {online}
- {constants.Emojis.status_idle} {idle}
- {constants.Emojis.status_dnd} {dnd}
- {constants.Emojis.status_offline} {offline}
+ {constants.Emojis.status_online} {statuses[Status.online]:,}
+ {constants.Emojis.status_idle} {statuses[Status.idle]:,}
+ {constants.Emojis.status_dnd} {statuses[Status.dnd]:,}
+ {constants.Emojis.status_offline} {statuses[Status.offline]:,}
""")
- )
-
+ ).substitute({"channel_counts": channel_counts})
embed.set_thumbnail(url=ctx.guild.icon_url)
await ctx.send(embed=embed)
@@ -169,14 +146,14 @@ class Information(Cog):
user = ctx.author
# Do a role check if this is being executed on someone other than the caller
- if user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES):
+ elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES):
await ctx.send("You may not use this command on users other than yourself.")
return
# Non-staff may only do this in #bot-commands
if not with_role_check(ctx, *constants.STAFF_ROLES):
- if not ctx.channel.id == constants.Channels.bot:
- raise InChannelCheckFailure(constants.Channels.bot)
+ if not ctx.channel.id == constants.Channels.bot_commands:
+ raise InChannelCheckFailure(constants.Channels.bot_commands)
embed = await self.create_user_embed(ctx, user)
@@ -202,7 +179,7 @@ class Information(Cog):
name = f"{user.nick} ({name})"
joined = time_since(user.joined_at, precision="days")
- roles = ", ".join(role.mention for role in user.roles if role.name != "@everyone")
+ roles = ", ".join(role.mention for role in user.roles[1:])
description = [
textwrap.dedent(f"""
@@ -355,14 +332,14 @@ class Information(Cog):
@cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES)
@group(invoke_without_command=True)
- @in_channel(constants.Channels.bot, bypass_roles=constants.STAFF_ROLES)
- async def raw(self, ctx: Context, *, message: discord.Message, json: bool = False) -> None:
+ @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES)
+ async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None:
"""Shows information about the raw API response."""
# I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling
# doing this extra request is also much easier than trying to convert everything back into a dictionary again
raw_data = await ctx.bot.http.get_message(message.channel.id, message.id)
- paginator = commands.Paginator()
+ paginator = Paginator()
def add_content(title: str, content: str) -> None:
paginator.add_line(f'== {title} ==\n')
@@ -390,7 +367,7 @@ class Information(Cog):
await ctx.send(page)
@raw.command()
- async def json(self, ctx: Context, message: discord.Message) -> None:
+ async def json(self, ctx: Context, message: Message) -> None:
"""Shows information about the raw API response in a copy-pasteable Python format."""
await ctx.invoke(self.raw, message=message, json=True)
diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py
index 985f28ce5..1d062b0c2 100644
--- a/bot/cogs/jams.py
+++ b/bot/cogs/jams.py
@@ -18,7 +18,7 @@ class CodeJams(commands.Cog):
self.bot = bot
@commands.command()
- @with_role(Roles.admin)
+ @with_role(Roles.admins)
async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None:
"""
Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team.
@@ -95,10 +95,10 @@ class CodeJams(commands.Cog):
)
# Assign team leader role
- await members[0].add_roles(ctx.guild.get_role(Roles.team_leader))
+ await members[0].add_roles(ctx.guild.get_role(Roles.team_leaders))
# Assign rest of roles
- jammer_role = ctx.guild.get_role(Roles.jammer)
+ jammer_role = ctx.guild.get_role(Roles.jammers)
for member in members:
await member.add_roles(jammer_role)
diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py
index d1b7dcab3..94fa2b139 100644
--- a/bot/cogs/logging.py
+++ b/bot/cogs/logging.py
@@ -20,7 +20,7 @@ class Logging(Cog):
async def startup_greeting(self) -> None:
"""Announce our presence to the configured devlog channel."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
log.info("Bot connected!")
embed = Embed(description="Connected!")
@@ -34,7 +34,7 @@ class Logging(Cog):
)
if not DEBUG_MODE:
- await self.bot.get_channel(Channels.devlog).send(embed=embed)
+ await self.bot.get_channel(Channels.dev_log).send(embed=embed)
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py
index f4e296df9..9ea17b2b3 100644
--- a/bot/cogs/moderation/infractions.py
+++ b/bot/cogs/moderation/infractions.py
@@ -313,6 +313,6 @@ class Infractions(InfractionScheduler, commands.Cog):
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""Send a notification to the invoking context on a Union failure."""
if isinstance(error, commands.BadUnionArgument):
- if discord.User in error.converters:
+ if discord.User in error.converters or discord.Member in error.converters:
await ctx.send(str(error.errors[0]))
error.handled = True
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index f2964cd78..f74089056 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -129,7 +129,9 @@ class ModManagement(commands.Cog):
# Re-schedule infraction if the expiration has been updated
if 'expires_at' in request_data:
- self.infractions_cog.cancel_task(new_infraction['id'])
+ # A scheduled task should only exist if the old infraction wasn't permanent
+ if old_infraction['expires_at']:
+ self.infractions_cog.cancel_task(new_infraction['id'])
# If the infraction was not marked as permanent, schedule a new expiration task
if request_data['expires_at']:
diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py
index e8ae0dbe6..59ae6b587 100644
--- a/bot/cogs/moderation/modlog.py
+++ b/bot/cogs/moderation/modlog.py
@@ -87,7 +87,7 @@ class ModLog(Cog, name="ModLog"):
title: t.Optional[str],
text: str,
thumbnail: t.Optional[t.Union[str, discord.Asset]] = None,
- channel_id: int = Channels.modlog,
+ channel_id: int = Channels.mod_log,
ping_everyone: bool = False,
files: t.Optional[t.List[discord.File]] = None,
content: t.Optional[str] = None,
@@ -377,7 +377,7 @@ class ModLog(Cog, name="ModLog"):
Icons.user_ban, Colours.soft_red,
"User banned", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"):
Icons.sign_in, Colours.soft_green,
"User joined", message,
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -416,7 +416,7 @@ class ModLog(Cog, name="ModLog"):
Icons.sign_out, Colours.soft_red,
"User left", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -433,7 +433,7 @@ class ModLog(Cog, name="ModLog"):
Icons.user_unban, Colour.blurple(),
"User unbanned", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.modlog
+ channel_id=Channels.mod_log
)
@Cog.listener()
@@ -529,7 +529,7 @@ class ModLog(Cog, name="ModLog"):
Icons.user_update, Colour.blurple(),
"Member updated", message,
thumbnail=after.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -538,7 +538,7 @@ class ModLog(Cog, name="ModLog"):
channel = message.channel
author = message.author
- if message.guild.id != GuildConstant.id or channel.id in GuildConstant.ignored:
+ if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist:
return
self._cached_deletes.append(message.id)
@@ -591,7 +591,7 @@ class ModLog(Cog, name="ModLog"):
@Cog.listener()
async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None:
"""Log raw message delete event to message change log."""
- if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.ignored:
+ if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist:
return
await asyncio.sleep(1) # Wait here in case the normal event was fired
@@ -635,7 +635,7 @@ class ModLog(Cog, name="ModLog"):
if (
not msg_before.guild
or msg_before.guild.id != GuildConstant.id
- or msg_before.channel.id in GuildConstant.ignored
+ or msg_before.channel.id in GuildConstant.modlog_blacklist
or msg_before.author.bot
):
return
@@ -717,7 +717,7 @@ class ModLog(Cog, name="ModLog"):
if (
not message.guild
or message.guild.id != GuildConstant.id
- or message.channel.id in GuildConstant.ignored
+ or message.channel.id in GuildConstant.modlog_blacklist
or message.author.bot
):
return
@@ -769,7 +769,7 @@ class ModLog(Cog, name="ModLog"):
"""Log member voice state changes to the voice log channel."""
if (
member.guild.id != GuildConstant.id
- or (before.channel and before.channel.id in GuildConstant.ignored)
+ or (before.channel and before.channel.id in GuildConstant.modlog_blacklist)
):
return
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py
index e14c302cb..45cf5ec8a 100644
--- a/bot/cogs/moderation/scheduler.py
+++ b/bot/cogs/moderation/scheduler.py
@@ -38,7 +38,7 @@ class InfractionScheduler(Scheduler):
async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None:
"""Schedule expiration for previous infractions."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
log.trace(f"Rescheduling infractions for {self.__class__.__name__}.")
@@ -307,18 +307,25 @@ class InfractionScheduler(Scheduler):
Infractions of unsupported types will raise a ValueError.
"""
guild = self.bot.get_guild(constants.Guild.id)
- mod_role = guild.get_role(constants.Roles.moderator)
+ mod_role = guild.get_role(constants.Roles.moderators)
user_id = infraction["user"]
+ actor = infraction["actor"]
type_ = infraction["type"]
id_ = infraction["id"]
+ inserted_at = infraction["inserted_at"]
+ expiry = infraction["expires_at"]
log.info(f"Marking infraction #{id_} as inactive (expired).")
+ expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None
+ created = time.format_infraction_with_duration(inserted_at, expiry)
+
log_content = None
log_text = {
- "Member": str(user_id),
- "Actor": str(self.bot.user),
- "Reason": infraction["reason"]
+ "Member": f"<@{user_id}>",
+ "Actor": str(self.bot.get_user(actor) or actor),
+ "Reason": infraction["reason"],
+ "Created": created,
}
try:
@@ -384,14 +391,19 @@ class InfractionScheduler(Scheduler):
if send_log:
log_title = f"expiration failed" if "Failure" in log_text else "expired"
+ user = self.bot.get_user(user_id)
+ avatar = user.avatar_url_as(static_format="png") if user else None
+
log.trace(f"Sending deactivation mod log for infraction #{id_}.")
await self.mod_log.send_log_message(
icon_url=utils.INFRACTION_ICONS[type_][1],
colour=Colours.soft_green,
title=f"Infraction {log_title}: {type_}",
+ thumbnail=avatar,
text="\n".join(f"{k}: {v}" for k, v in log_text.items()),
footer=f"ID: {id_}",
content=log_content,
+
)
return log_text
diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py
index 050c847ac..c41874a95 100644
--- a/bot/cogs/moderation/superstarify.py
+++ b/bot/cogs/moderation/superstarify.py
@@ -109,7 +109,8 @@ class Superstarify(InfractionScheduler, Cog):
ctx: Context,
member: Member,
duration: Expiry,
- reason: str = None
+ *,
+ reason: str = None,
) -> None:
"""
Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname.
diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py
index bf777ea5a..81511f99d 100644
--- a/bot/cogs/off_topic_names.py
+++ b/bot/cogs/off_topic_names.py
@@ -88,7 +88,7 @@ class OffTopicNames(Cog):
async def init_offtopic_updater(self) -> None:
"""Start off-topic channel updating event loop if it hasn't already started."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if self.updater_task is None:
coro = update_names(self.bot)
self.updater_task = self.bot.loop.create_task(coro)
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index aa487f18e..5a7fa100f 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -43,12 +43,12 @@ class Reddit(Cog):
def cog_unload(self) -> None:
"""Stop the loop task and revoke the access token when the cog is unloaded."""
self.auto_poster_loop.cancel()
- if self.access_token.expires_at < datetime.utcnow():
- self.revoke_access_token()
+ if self.access_token and self.access_token.expires_at > datetime.utcnow():
+ asyncio.create_task(self.revoke_access_token())
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if not self.webhook:
self.webhook = await self.bot.fetch_webhook(Webhooks.reddit)
@@ -83,7 +83,7 @@ class Reddit(Cog):
expires_at=datetime.utcnow() + timedelta(seconds=expiration)
)
- log.debug(f"New token acquired; expires on {self.access_token.expires_at}")
+ log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")
return
else:
log.debug(
@@ -208,7 +208,7 @@ class Reddit(Cog):
await asyncio.sleep(seconds_until)
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
if not self.webhook:
await self.bot.fetch_webhook(Webhooks.reddit)
@@ -290,4 +290,7 @@ class Reddit(Cog):
def setup(bot: Bot) -> None:
"""Load the Reddit cog."""
+ if not RedditConfig.secret or not RedditConfig.client_id:
+ log.error("Credentials not provided, cog not loaded.")
+ return
bot.add_cog(Reddit(bot))
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 45bf9a8f4..041791056 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -2,16 +2,17 @@ import asyncio
import logging
import random
import textwrap
+import typing as t
from datetime import datetime, timedelta
from operator import itemgetter
-from typing import Optional
+import discord
+from dateutil.parser import isoparse
from dateutil.relativedelta import relativedelta
-from discord import Colour, Embed, Message
from discord.ext.commands import Cog, Context, group
from bot.bot import Bot
-from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES
+from bot.constants import Guild, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES
from bot.converters import Duration
from bot.pagination import LinePaginator
from bot.utils.checks import without_role_check
@@ -20,7 +21,7 @@ from bot.utils.time import humanize_delta, wait_until
log = logging.getLogger(__name__)
-WHITELISTED_CHANNELS = (Channels.bot,)
+WHITELISTED_CHANNELS = Guild.reminder_whitelist
MAXIMUM_REMINDERS = 5
@@ -35,7 +36,7 @@ class Reminders(Scheduler, Cog):
async def reschedule_reminders(self) -> None:
"""Get all current reminders from the API and reschedule them."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
response = await self.bot.api_client.get(
'bot/reminders',
params={'active': 'true'}
@@ -45,29 +46,64 @@ class Reminders(Scheduler, Cog):
loop = asyncio.get_event_loop()
for reminder in response:
- remind_at = datetime.fromisoformat(reminder['expiration'][:-1])
+ is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False)
+ if not is_valid:
+ continue
+
+ remind_at = isoparse(reminder['expiration']).replace(tzinfo=None)
# If the reminder is already overdue ...
if remind_at < now:
late = relativedelta(now, remind_at)
await self.send_reminder(reminder, late)
-
else:
self.schedule_task(loop, reminder["id"], reminder)
+ def ensure_valid_reminder(
+ self,
+ reminder: dict,
+ cancel_task: bool = True
+ ) -> t.Tuple[bool, discord.User, discord.TextChannel]:
+ """Ensure reminder author and channel can be fetched otherwise delete the reminder."""
+ user = self.bot.get_user(reminder['author'])
+ channel = self.bot.get_channel(reminder['channel_id'])
+ is_valid = True
+ if not user or not channel:
+ is_valid = False
+ log.info(
+ f"Reminder {reminder['id']} invalid: "
+ f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}."
+ )
+ asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task))
+
+ return is_valid, user, channel
+
@staticmethod
- async def _send_confirmation(ctx: Context, on_success: str) -> None:
+ async def _send_confirmation(
+ ctx: Context,
+ on_success: str,
+ reminder_id: str,
+ delivery_dt: t.Optional[datetime],
+ ) -> None:
"""Send an embed confirming the reminder change was made successfully."""
- embed = Embed()
- embed.colour = Colour.green()
+ embed = discord.Embed()
+ embed.colour = discord.Colour.green()
embed.title = random.choice(POSITIVE_REPLIES)
embed.description = on_success
+
+ footer_str = f"ID: {reminder_id}"
+ if delivery_dt:
+ # Reminder deletion will have a `None` `delivery_dt`
+ footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}"
+
+ embed.set_footer(text=footer_str)
+
await ctx.send(embed=embed)
async def _scheduled_task(self, reminder: dict) -> None:
"""A coroutine which sends the reminder once the time is reached, and cancels the running task."""
reminder_id = reminder["id"]
- reminder_datetime = datetime.fromisoformat(reminder['expiration'][:-1])
+ reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None)
# Send the reminder message once the desired duration has passed
await wait_until(reminder_datetime)
@@ -79,12 +115,13 @@ class Reminders(Scheduler, Cog):
# Now we can begone with it from our schedule list.
self.cancel_task(reminder_id)
- async def _delete_reminder(self, reminder_id: str) -> None:
+ async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None:
"""Delete a reminder from the database, given its ID, and cancel the running task."""
await self.bot.api_client.delete('bot/reminders/' + str(reminder_id))
- # Now we can remove it from the schedule list
- self.cancel_task(reminder_id)
+ if cancel_task:
+ # Now we can remove it from the schedule list
+ self.cancel_task(reminder_id)
async def _reschedule_reminder(self, reminder: dict) -> None:
"""Reschedule a reminder object."""
@@ -95,11 +132,12 @@ class Reminders(Scheduler, Cog):
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
- channel = self.bot.get_channel(reminder["channel_id"])
- user = self.bot.get_user(reminder["author"])
+ is_valid, user, channel = self.ensure_valid_reminder(reminder)
+ if not is_valid:
+ return
- embed = Embed()
- embed.colour = Colour.blurple()
+ embed = discord.Embed()
+ embed.colour = discord.Colour.blurple()
embed.set_author(
icon_url=Icons.remind_blurple,
name="It has arrived!"
@@ -111,7 +149,7 @@ class Reminders(Scheduler, Cog):
embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})"
if late:
- embed.colour = Colour.red()
+ embed.colour = discord.Colour.red()
embed.set_author(
icon_url=Icons.remind_red,
name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!"
@@ -129,20 +167,20 @@ class Reminders(Scheduler, Cog):
await ctx.invoke(self.new_reminder, expiration=expiration, content=content)
@remind_group.command(name="new", aliases=("add", "create"))
- async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> Optional[Message]:
+ async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> t.Optional[discord.Message]:
"""
Set yourself a simple reminder.
Expiration is parsed per: http://strftime.org/
"""
- embed = Embed()
+ embed = discord.Embed()
# If the user is not staff, we need to verify whether or not to make a reminder at all.
if without_role_check(ctx, *STAFF_ROLES):
# If they don't have permission to set a reminder in this channel
if ctx.channel.id not in WHITELISTED_CHANNELS:
- embed.colour = Colour.red()
+ embed.colour = discord.Colour.red()
embed.title = random.choice(NEGATIVE_REPLIES)
embed.description = "Sorry, you can't do that here!"
@@ -159,7 +197,7 @@ class Reminders(Scheduler, Cog):
# Let's limit this, so we don't get 10 000
# reminders from kip or something like that :P
if len(active_reminders) > MAXIMUM_REMINDERS:
- embed.colour = Colour.red()
+ embed.colour = discord.Colour.red()
embed.title = random.choice(NEGATIVE_REPLIES)
embed.description = "You have too many active reminders!"
@@ -178,18 +216,21 @@ class Reminders(Scheduler, Cog):
)
now = datetime.utcnow() - timedelta(seconds=1)
+ humanized_delta = humanize_delta(relativedelta(expiration, now))
# Confirm to the user that it worked.
await self._send_confirmation(
ctx,
- on_success=f"Your reminder will arrive in {humanize_delta(relativedelta(expiration, now))}!"
+ on_success=f"Your reminder will arrive in {humanized_delta}!",
+ reminder_id=reminder["id"],
+ delivery_dt=expiration,
)
loop = asyncio.get_event_loop()
self.schedule_task(loop, reminder["id"], reminder)
@remind_group.command(name="list")
- async def list_reminders(self, ctx: Context) -> Optional[Message]:
+ async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]:
"""View a paginated embed of all reminders for your user."""
# Get all the user's reminders from the database.
data = await self.bot.api_client.get(
@@ -212,7 +253,7 @@ class Reminders(Scheduler, Cog):
for content, remind_at, id_ in reminders:
# Parse and humanize the time, make it pretty :D
- remind_datetime = datetime.fromisoformat(remind_at[:-1])
+ remind_datetime = isoparse(remind_at).replace(tzinfo=None)
time = humanize_delta(relativedelta(remind_datetime, now))
text = textwrap.dedent(f"""
@@ -222,8 +263,8 @@ class Reminders(Scheduler, Cog):
lines.append(text)
- embed = Embed()
- embed.colour = Colour.blurple()
+ embed = discord.Embed()
+ embed.colour = discord.Colour.blurple()
embed.title = f"Reminders for {ctx.author}"
# Remind the user that they have no reminders :^)
@@ -232,7 +273,7 @@ class Reminders(Scheduler, Cog):
return await ctx.send(embed=embed)
# Construct the embed and paginate it.
- embed.colour = Colour.blurple()
+ embed.colour = discord.Colour.blurple()
await LinePaginator.paginate(
lines,
@@ -261,7 +302,10 @@ class Reminders(Scheduler, Cog):
# Send a confirmation message to the channel
await self._send_confirmation(
- ctx, on_success="That reminder has been edited successfully!"
+ ctx,
+ on_success="That reminder has been edited successfully!",
+ reminder_id=id_,
+ delivery_dt=expiration,
)
await self._reschedule_reminder(reminder)
@@ -275,18 +319,27 @@ class Reminders(Scheduler, Cog):
json={'content': content}
)
+ # Parse the reminder expiration back into a datetime for the confirmation message
+ expiration = isoparse(reminder['expiration']).replace(tzinfo=None)
+
# Send a confirmation message to the channel
await self._send_confirmation(
- ctx, on_success="That reminder has been edited successfully!"
+ ctx,
+ on_success="That reminder has been edited successfully!",
+ reminder_id=id_,
+ delivery_dt=expiration,
)
await self._reschedule_reminder(reminder)
- @remind_group.command("delete", aliases=("remove",))
+ @remind_group.command("delete", aliases=("remove", "cancel"))
async def delete_reminder(self, ctx: Context, id_: int) -> None:
"""Delete one of your active reminders."""
await self._delete_reminder(id_)
await self._send_confirmation(
- ctx, on_success="That reminder has been deleted successfully!"
+ ctx,
+ on_success="That reminder has been deleted successfully!",
+ reminder_id=id_,
+ delivery_dt=None,
)
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index d52027ac6..44764e7e9 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -38,7 +38,7 @@ RAW_CODE_REGEX = re.compile(
)
MAX_PASTE_LEN = 1000
-EVAL_ROLES = (Roles.helpers, Roles.moderator, Roles.admin, Roles.owner, Roles.rockstars, Roles.partners)
+EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
SIGKILL = 9
@@ -245,7 +245,7 @@ class Snekbox(Cog):
@command(name="eval", aliases=("e",))
@guild_only()
- @in_channel(Channels.bot, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
+ @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
async def eval_command(self, ctx: Context, *, code: str = None) -> None:
"""
Run Python code and get the results.
diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py
index 4e6ed156b..5708be3f4 100644
--- a/bot/cogs/sync/cog.py
+++ b/bot/cogs/sync/cog.py
@@ -1,7 +1,7 @@
import logging
-from typing import Callable, Dict, Iterable, Union
+from typing import Any, Dict
-from discord import Guild, Member, Role, User
+from discord import Member, Role, User
from discord.ext import commands
from discord.ext.commands import Cog, Context
@@ -16,45 +16,28 @@ log = logging.getLogger(__name__)
class Sync(Cog):
"""Captures relevant events and sends them to the site."""
- # The server to synchronize events on.
- # Note that setting this wrongly will result in things getting deleted
- # that possibly shouldn't be.
- SYNC_SERVER_ID = constants.Guild.id
-
- # An iterable of callables that are called when the bot is ready.
- ON_READY_SYNCERS: Iterable[Callable[[Bot, Guild], None]] = (
- syncers.sync_roles,
- syncers.sync_users
- )
-
def __init__(self, bot: Bot) -> None:
self.bot = bot
+ self.role_syncer = syncers.RoleSyncer(self.bot)
+ self.user_syncer = syncers.UserSyncer(self.bot)
self.bot.loop.create_task(self.sync_guild())
async def sync_guild(self) -> None:
"""Syncs the roles/users of the guild with the database."""
- await self.bot.wait_until_ready()
- guild = self.bot.get_guild(self.SYNC_SERVER_ID)
- if guild is not None:
- for syncer in self.ON_READY_SYNCERS:
- syncer_name = syncer.__name__[5:] # drop off `sync_`
- log.info("Starting `%s` syncer.", syncer_name)
- total_created, total_updated, total_deleted = await syncer(self.bot, guild)
- if total_deleted is None:
- log.info(
- f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`."
- )
- else:
- log.info(
- f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`, "
- f"deleted `{total_deleted}`."
- )
-
- async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None:
+ await self.bot.wait_until_guild_available()
+
+ guild = self.bot.get_guild(constants.Guild.id)
+ if guild is None:
+ return
+
+ for syncer in (self.role_syncer, self.user_syncer):
+ await syncer.sync(guild)
+
+ async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None:
"""Send a PATCH request to partially update a user in the database."""
try:
- await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information)
+ await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information)
except ResponseCodeError as e:
if e.response.status != 404:
raise
@@ -82,12 +65,14 @@ class Sync(Cog):
@Cog.listener()
async def on_guild_role_update(self, before: Role, after: Role) -> None:
"""Syncs role with the database if any of the stored attributes were updated."""
- if (
- before.name != after.name
- or before.colour != after.colour
- or before.permissions != after.permissions
- or before.position != after.position
- ):
+ was_updated = (
+ before.name != after.name
+ or before.colour != after.colour
+ or before.permissions != after.permissions
+ or before.position != after.position
+ )
+
+ if was_updated:
await self.bot.api_client.put(
f'bot/roles/{after.id}',
json={
@@ -137,18 +122,8 @@ class Sync(Cog):
@Cog.listener()
async def on_member_remove(self, member: Member) -> None:
- """Updates the user information when a member leaves the guild."""
- await self.bot.api_client.put(
- f'bot/users/{member.id}',
- json={
- 'avatar_hash': member.avatar,
- 'discriminator': int(member.discriminator),
- 'id': member.id,
- 'in_guild': False,
- 'name': member.name,
- 'roles': sorted(role.id for role in member.roles)
- }
- )
+ """Set the in_guild field to False when a member leaves the guild."""
+ await self.patch_user(member.id, updated_information={"in_guild": False})
@Cog.listener()
async def on_member_update(self, before: Member, after: Member) -> None:
@@ -160,7 +135,8 @@ class Sync(Cog):
@Cog.listener()
async def on_user_update(self, before: User, after: User) -> None:
"""Update the user information in the database if a relevant change is detected."""
- if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")):
+ attrs = ("name", "discriminator", "avatar")
+ if any(getattr(before, attr) != getattr(after, attr) for attr in attrs):
updated_information = {
"name": after.name,
"discriminator": int(after.discriminator),
@@ -176,25 +152,11 @@ class Sync(Cog):
@sync_group.command(name='roles')
@commands.has_permissions(administrator=True)
async def sync_roles_command(self, ctx: Context) -> None:
- """Manually synchronize the guild's roles with the roles on the site."""
- initial_response = await ctx.send("📊 Synchronizing roles.")
- total_created, total_updated, total_deleted = await syncers.sync_roles(self.bot, ctx.guild)
- await initial_response.edit(
- content=(
- f"👌 Role synchronization complete, created **{total_created}** "
- f", updated **{total_created}** roles, and deleted **{total_deleted}** roles."
- )
- )
+ """Manually synchronise the guild's roles with the roles on the site."""
+ await self.role_syncer.sync(ctx.guild, ctx)
@sync_group.command(name='users')
@commands.has_permissions(administrator=True)
async def sync_users_command(self, ctx: Context) -> None:
- """Manually synchronize the guild's users with the users on the site."""
- initial_response = await ctx.send("📊 Synchronizing users.")
- total_created, total_updated, total_deleted = await syncers.sync_users(self.bot, ctx.guild)
- await initial_response.edit(
- content=(
- f"👌 User synchronization complete, created **{total_created}** "
- f"and updated **{total_created}** users."
- )
- )
+ """Manually synchronise the guild's users with the users on the site."""
+ await self.user_syncer.sync(ctx.guild, ctx)
diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py
index 14cf51383..d6891168f 100644
--- a/bot/cogs/sync/syncers.py
+++ b/bot/cogs/sync/syncers.py
@@ -1,235 +1,342 @@
+import abc
+import logging
+import typing as t
from collections import namedtuple
-from typing import Dict, Set, Tuple
+from functools import partial
-from discord import Guild
+from discord import Guild, HTTPException, Member, Message, Reaction, User
+from discord.ext.commands import Context
+from bot import constants
+from bot.api import ResponseCodeError
from bot.bot import Bot
+log = logging.getLogger(__name__)
+
# These objects are declared as namedtuples because tuples are hashable,
# something that we make use of when diffing site roles against guild roles.
-Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
-User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild'))
-
-
-def get_roles_for_sync(
- guild_roles: Set[Role], api_roles: Set[Role]
-) -> Tuple[Set[Role], Set[Role], Set[Role]]:
- """
- Determine which roles should be created or updated on the site.
-
- Arguments:
- guild_roles (Set[Role]):
- Roles that were found on the guild at startup.
-
- api_roles (Set[Role]):
- Roles that were retrieved from the API at startup.
-
- Returns:
- Tuple[Set[Role], Set[Role]. Set[Role]]:
- A tuple with three elements. The first element represents
- roles to be created on the site, meaning that they were
- present on the cached guild but not on the API. The second
- element represents roles to be updated, meaning they were
- present on both the cached guild and the API but non-ID
- fields have changed inbetween. The third represents roles
- to be deleted on the site, meaning the roles are present on
- the API but not in the cached guild.
- """
- guild_role_ids = {role.id for role in guild_roles}
- api_role_ids = {role.id for role in api_roles}
- new_role_ids = guild_role_ids - api_role_ids
- deleted_role_ids = api_role_ids - guild_role_ids
-
- # New roles are those which are on the cached guild but not on the
- # API guild, going by the role ID. We need to send them in for creation.
- roles_to_create = {role for role in guild_roles if role.id in new_role_ids}
- roles_to_update = guild_roles - api_roles - roles_to_create
- roles_to_delete = {role for role in api_roles if role.id in deleted_role_ids}
- return roles_to_create, roles_to_update, roles_to_delete
-
-
-async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]:
- """
- Synchronize roles found on the given `guild` with the ones on the API.
-
- Arguments:
- bot (bot.bot.Bot):
- The bot instance that we're running with.
-
- guild (discord.Guild):
- The guild instance from the bot's cache
- to synchronize roles with.
-
- Returns:
- Tuple[int, int, int]:
- A tuple with three integers representing how many roles were created
- (element `0`) , how many roles were updated (element `1`), and how many
- roles were deleted (element `2`) on the API.
- """
- roles = await bot.api_client.get('bot/roles')
-
- # Pack API roles and guild roles into one common format,
- # which is also hashable. We need hashability to be able
- # to compare these easily later using sets.
- api_roles = {Role(**role_dict) for role_dict in roles}
- guild_roles = {
- Role(
- id=role.id, name=role.name,
- colour=role.colour.value, permissions=role.permissions.value,
- position=role.position,
- )
- for role in guild.roles
- }
- roles_to_create, roles_to_update, roles_to_delete = get_roles_for_sync(guild_roles, api_roles)
-
- for role in roles_to_create:
- await bot.api_client.post(
- 'bot/roles',
- json={
- 'id': role.id,
- 'name': role.name,
- 'colour': role.colour,
- 'permissions': role.permissions,
- 'position': role.position,
- }
- )
+_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position'))
+_User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild'))
+_Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
- for role in roles_to_update:
- await bot.api_client.put(
- f'bot/roles/{role.id}',
- json={
- 'id': role.id,
- 'name': role.name,
- 'colour': role.colour,
- 'permissions': role.permissions,
- 'position': role.position,
- }
- )
- for role in roles_to_delete:
- await bot.api_client.delete(f'bot/roles/{role.id}')
-
- return len(roles_to_create), len(roles_to_update), len(roles_to_delete)
-
-
-def get_users_for_sync(
- guild_users: Dict[int, User], api_users: Dict[int, User]
-) -> Tuple[Set[User], Set[User]]:
- """
- Determine which users should be created or updated on the website.
-
- Arguments:
- guild_users (Dict[int, User]):
- A mapping of user IDs to user data, populated from the
- guild cached on the running bot instance.
-
- api_users (Dict[int, User]):
- A mapping of user IDs to user data, populated from the API's
- current inventory of all users.
-
- Returns:
- Tuple[Set[User], Set[User]]:
- Two user sets as a tuple. The first element represents users
- to be created on the website, these are users that are present
- in the cached guild data but not in the API at all, going by
- their ID. The second element represents users to update. It is
- populated by users which are present on both the API and the
- guild, but where the attribute of a user on the API is not
- equal to the attribute of the user on the guild.
- """
- users_to_create = set()
- users_to_update = set()
-
- for api_user in api_users.values():
- guild_user = guild_users.get(api_user.id)
- if guild_user is not None:
- if api_user != guild_user:
- users_to_update.add(guild_user)
-
- elif api_user.in_guild:
- # The user is known on the API but not the guild, and the
- # API currently specifies that the user is a member of the guild.
- # This means that the user has left since the last sync.
- # Update the `in_guild` attribute of the user on the site
- # to signify that the user left.
- new_api_user = api_user._replace(in_guild=False)
- users_to_update.add(new_api_user)
-
- new_user_ids = set(guild_users.keys()) - set(api_users.keys())
- for user_id in new_user_ids:
- # The user is known on the guild but not on the API. This means
- # that the user has joined since the last sync. Create it.
- new_user = guild_users[user_id]
- users_to_create.add(new_user)
-
- return users_to_create, users_to_update
-
-
-async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]:
- """
- Synchronize users found in the given `guild` with the ones in the API.
-
- Arguments:
- bot (bot.bot.Bot):
- The bot instance that we're running with.
-
- guild (discord.Guild):
- The guild instance from the bot's cache
- to synchronize roles with.
-
- Returns:
- Tuple[int, int, None]:
- A tuple with two integers, representing how many users were created
- (element `0`) and how many users were updated (element `1`), and `None`
- to indicate that a user sync never deletes entries from the API.
- """
- current_users = await bot.api_client.get('bot/users')
-
- # Pack API users and guild users into one common format,
- # which is also hashable. We need hashability to be able
- # to compare these easily later using sets.
- api_users = {
- user_dict['id']: User(
- roles=tuple(sorted(user_dict.pop('roles'))),
- **user_dict
- )
- for user_dict in current_users
- }
- guild_users = {
- member.id: User(
- id=member.id, name=member.name,
- discriminator=int(member.discriminator), avatar_hash=member.avatar,
- roles=tuple(sorted(role.id for role in member.roles)), in_guild=True
- )
- for member in guild.members
- }
-
- users_to_create, users_to_update = get_users_for_sync(guild_users, api_users)
-
- for user in users_to_create:
- await bot.api_client.post(
- 'bot/users',
- json={
- 'avatar_hash': user.avatar_hash,
- 'discriminator': user.discriminator,
- 'id': user.id,
- 'in_guild': user.in_guild,
- 'name': user.name,
- 'roles': list(user.roles)
- }
+class Syncer(abc.ABC):
+ """Base class for synchronising the database with objects in the Discord cache."""
+
+ _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> "
+ _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark)
+
+ def __init__(self, bot: Bot) -> None:
+ self.bot = bot
+
+ @property
+ @abc.abstractmethod
+ def name(self) -> str:
+ """The name of the syncer; used in output messages and logging."""
+ raise NotImplementedError # pragma: no cover
+
+ async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]:
+ """
+ Send a prompt to confirm or abort a sync using reactions and return the sent message.
+
+ If a message is given, it is edited to display the prompt and reactions. Otherwise, a new
+ message is sent to the dev-core channel and mentions the core developers role. If the
+ channel cannot be retrieved, return None.
+ """
+ log.trace(f"Sending {self.name} sync confirmation prompt.")
+
+ msg_content = (
+ f'Possible cache issue while syncing {self.name}s. '
+ f'More than {constants.Sync.max_diff} {self.name}s were changed. '
+ f'React to confirm or abort the sync.'
)
- for user in users_to_update:
- await bot.api_client.put(
- f'bot/users/{user.id}',
- json={
- 'avatar_hash': user.avatar_hash,
- 'discriminator': user.discriminator,
- 'id': user.id,
- 'in_guild': user.in_guild,
- 'name': user.name,
- 'roles': list(user.roles)
- }
+ # Send to core developers if it's an automatic sync.
+ if not message:
+ log.trace("Message not provided for confirmation; creating a new one in dev-core.")
+ channel = self.bot.get_channel(constants.Channels.dev_core)
+
+ if not channel:
+ log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.")
+ try:
+ channel = await self.bot.fetch_channel(constants.Channels.dev_core)
+ except HTTPException:
+ log.exception(
+ f"Failed to fetch channel for sending sync confirmation prompt; "
+ f"aborting {self.name} sync."
+ )
+ return None
+
+ message = await channel.send(f"{self._CORE_DEV_MENTION}{msg_content}")
+ else:
+ await message.edit(content=msg_content)
+
+ # Add the initial reactions.
+ log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.")
+ for emoji in self._REACTION_EMOJIS:
+ await message.add_reaction(emoji)
+
+ return message
+
+ def _reaction_check(
+ self,
+ author: Member,
+ message: Message,
+ reaction: Reaction,
+ user: t.Union[Member, User]
+ ) -> bool:
+ """
+ Return True if the `reaction` is a valid confirmation or abort reaction on `message`.
+
+ If the `author` of the prompt is a bot, then a reaction by any core developer will be
+ considered valid. Otherwise, the author of the reaction (`user`) will have to be the
+ `author` of the prompt.
+ """
+ # For automatic syncs, check for the core dev role instead of an exact author
+ has_role = any(constants.Roles.core_developers == role.id for role in user.roles)
+ return (
+ reaction.message.id == message.id
+ and not user.bot
+ and (has_role if author.bot else user == author)
+ and str(reaction.emoji) in self._REACTION_EMOJIS
)
- return len(users_to_create), len(users_to_update), None
+ async def _wait_for_confirmation(self, author: Member, message: Message) -> bool:
+ """
+ Wait for a confirmation reaction by `author` on `message` and return True if confirmed.
+
+ Uses the `_reaction_check` function to determine if a reaction is valid.
+
+ If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False.
+ To acknowledge the reaction (or lack thereof), `message` will be edited.
+ """
+ # Preserve the core-dev role mention in the message edits so users aren't confused about
+ # where notifications came from.
+ mention = self._CORE_DEV_MENTION if author.bot else ""
+
+ reaction = None
+ try:
+ log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.")
+ reaction, _ = await self.bot.wait_for(
+ 'reaction_add',
+ check=partial(self._reaction_check, author, message),
+ timeout=constants.Sync.confirm_timeout
+ )
+ except TimeoutError:
+ # reaction will remain none thus sync will be aborted in the finally block below.
+ log.debug(f"The {self.name} syncer confirmation prompt timed out.")
+ finally:
+ if str(reaction) == constants.Emojis.check_mark:
+ log.trace(f"The {self.name} syncer was confirmed.")
+ await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.')
+ return True
+ else:
+ log.warning(f"The {self.name} syncer was aborted or timed out!")
+ await message.edit(
+ content=f':warning: {mention}{self.name} sync aborted or timed out!'
+ )
+ return False
+
+ @abc.abstractmethod
+ async def _get_diff(self, guild: Guild) -> _Diff:
+ """Return the difference between the cache of `guild` and the database."""
+ raise NotImplementedError # pragma: no cover
+
+ @abc.abstractmethod
+ async def _sync(self, diff: _Diff) -> None:
+ """Perform the API calls for synchronisation."""
+ raise NotImplementedError # pragma: no cover
+
+ async def _get_confirmation_result(
+ self,
+ diff_size: int,
+ author: Member,
+ message: t.Optional[Message] = None
+ ) -> t.Tuple[bool, t.Optional[Message]]:
+ """
+ Prompt for confirmation and return a tuple of the result and the prompt message.
+
+ `diff_size` is the size of the diff of the sync. If it is greater than
+ `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the
+ sync and the `message` is an extant message to edit to display the prompt.
+
+ If confirmed or no confirmation was needed, the result is True. The returned message will
+ either be the given `message` or a new one which was created when sending the prompt.
+ """
+ log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.")
+ if diff_size > constants.Sync.max_diff:
+ message = await self._send_prompt(message)
+ if not message:
+ return False, None # Couldn't get channel.
+
+ confirmed = await self._wait_for_confirmation(author, message)
+ if not confirmed:
+ return False, message # Sync aborted.
+
+ return True, message
+
+ async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None:
+ """
+ Synchronise the database with the cache of `guild`.
+
+ If the differences between the cache and the database are greater than
+ `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core
+ channel. The confirmation can be optionally redirect to `ctx` instead.
+ """
+ log.info(f"Starting {self.name} syncer.")
+
+ message = None
+ author = self.bot.user
+ if ctx:
+ message = await ctx.send(f"📊 Synchronising {self.name}s.")
+ author = ctx.author
+
+ diff = await self._get_diff(guild)
+ diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict
+ totals = {k: len(v) for k, v in diff_dict.items() if v is not None}
+ diff_size = sum(totals.values())
+
+ confirmed, message = await self._get_confirmation_result(diff_size, author, message)
+ if not confirmed:
+ return
+
+ # Preserve the core-dev role mention in the message edits so users aren't confused about
+ # where notifications came from.
+ mention = self._CORE_DEV_MENTION if author.bot else ""
+
+ try:
+ await self._sync(diff)
+ except ResponseCodeError as e:
+ log.exception(f"{self.name} syncer failed!")
+
+ # Don't show response text because it's probably some really long HTML.
+ results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```"
+ content = f":x: {mention}Synchronisation of {self.name}s failed: {results}"
+ else:
+ results = ", ".join(f"{name} `{total}`" for name, total in totals.items())
+ log.info(f"{self.name} syncer finished: {results}.")
+ content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}"
+
+ if message:
+ await message.edit(content=content)
+
+
+class RoleSyncer(Syncer):
+ """Synchronise the database with roles in the cache."""
+
+ name = "role"
+
+ async def _get_diff(self, guild: Guild) -> _Diff:
+ """Return the difference of roles between the cache of `guild` and the database."""
+ log.trace("Getting the diff for roles.")
+ roles = await self.bot.api_client.get('bot/roles')
+
+ # Pack DB roles and guild roles into one common, hashable format.
+ # They're hashable so that they're easily comparable with sets later.
+ db_roles = {_Role(**role_dict) for role_dict in roles}
+ guild_roles = {
+ _Role(
+ id=role.id,
+ name=role.name,
+ colour=role.colour.value,
+ permissions=role.permissions.value,
+ position=role.position,
+ )
+ for role in guild.roles
+ }
+
+ guild_role_ids = {role.id for role in guild_roles}
+ api_role_ids = {role.id for role in db_roles}
+ new_role_ids = guild_role_ids - api_role_ids
+ deleted_role_ids = api_role_ids - guild_role_ids
+
+ # New roles are those which are on the cached guild but not on the
+ # DB guild, going by the role ID. We need to send them in for creation.
+ roles_to_create = {role for role in guild_roles if role.id in new_role_ids}
+ roles_to_update = guild_roles - db_roles - roles_to_create
+ roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids}
+
+ return _Diff(roles_to_create, roles_to_update, roles_to_delete)
+
+ async def _sync(self, diff: _Diff) -> None:
+ """Synchronise the database with the role cache of `guild`."""
+ log.trace("Syncing created roles...")
+ for role in diff.created:
+ await self.bot.api_client.post('bot/roles', json=role._asdict())
+
+ log.trace("Syncing updated roles...")
+ for role in diff.updated:
+ await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict())
+
+ log.trace("Syncing deleted roles...")
+ for role in diff.deleted:
+ await self.bot.api_client.delete(f'bot/roles/{role.id}')
+
+
+class UserSyncer(Syncer):
+ """Synchronise the database with users in the cache."""
+
+ name = "user"
+
+ async def _get_diff(self, guild: Guild) -> _Diff:
+ """Return the difference of users between the cache of `guild` and the database."""
+ log.trace("Getting the diff for users.")
+ users = await self.bot.api_client.get('bot/users')
+
+ # Pack DB roles and guild roles into one common, hashable format.
+ # They're hashable so that they're easily comparable with sets later.
+ db_users = {
+ user_dict['id']: _User(
+ roles=tuple(sorted(user_dict.pop('roles'))),
+ **user_dict
+ )
+ for user_dict in users
+ }
+ guild_users = {
+ member.id: _User(
+ id=member.id,
+ name=member.name,
+ discriminator=int(member.discriminator),
+ avatar_hash=member.avatar,
+ roles=tuple(sorted(role.id for role in member.roles)),
+ in_guild=True
+ )
+ for member in guild.members
+ }
+
+ users_to_create = set()
+ users_to_update = set()
+
+ for db_user in db_users.values():
+ guild_user = guild_users.get(db_user.id)
+ if guild_user is not None:
+ if db_user != guild_user:
+ users_to_update.add(guild_user)
+
+ elif db_user.in_guild:
+ # The user is known in the DB but not the guild, and the
+ # DB currently specifies that the user is a member of the guild.
+ # This means that the user has left since the last sync.
+ # Update the `in_guild` attribute of the user on the site
+ # to signify that the user left.
+ new_api_user = db_user._replace(in_guild=False)
+ users_to_update.add(new_api_user)
+
+ new_user_ids = set(guild_users.keys()) - set(db_users.keys())
+ for user_id in new_user_ids:
+ # The user is known on the guild but not on the API. This means
+ # that the user has joined since the last sync. Create it.
+ new_user = guild_users[user_id]
+ users_to_create.add(new_user)
+
+ return _Diff(users_to_create, users_to_update, None)
+
+ async def _sync(self, diff: _Diff) -> None:
+ """Synchronise the database with the user cache of `guild`."""
+ log.trace("Syncing created users...")
+ for user in diff.created:
+ await self.bot.api_client.post('bot/users', json=user._asdict())
+
+ log.trace("Syncing updated users...")
+ for user in diff.updated:
+ await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict())
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py
index 54a51921c..5da9a4148 100644
--- a/bot/cogs/tags.py
+++ b/bot/cogs/tags.py
@@ -15,8 +15,7 @@ from bot.pagination import LinePaginator
log = logging.getLogger(__name__)
TEST_CHANNELS = (
- Channels.devtest,
- Channels.bot,
+ Channels.bot_commands,
Channels.helpers
)
@@ -116,8 +115,10 @@ class Tags(Cog):
if _command_on_cooldown(tag_name):
time_left = Cooldowns.tags - (time.time() - self.tag_cooldowns[tag_name]["time"])
- log.warning(f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. "
- f"Cooldown ends in {time_left:.1f} seconds.")
+ log.info(
+ f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. "
+ f"Cooldown ends in {time_left:.1f} seconds."
+ )
return
await self._get_tags()
@@ -219,7 +220,7 @@ class Tags(Cog):
))
@tags_group.command(name='delete', aliases=('remove', 'rm', 'd'))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def delete_command(self, ctx: Context, *, tag_name: TagNameConverter) -> None:
"""Remove a tag from the database."""
await self.bot.api_client.delete(f'bot/tags/{tag_name}')
diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py
index da278011a..94b9d6b5a 100644
--- a/bot/cogs/utils.py
+++ b/bot/cogs/utils.py
@@ -89,7 +89,7 @@ class Utils(Cog):
await ctx.message.channel.send(embed=pep_embed)
@command()
- @in_channel(Channels.bot, bypass_roles=STAFF_ROLES)
+ @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES)
async def charinfo(self, ctx: Context, *, characters: str) -> None:
"""Shows you information on up to 25 unicode characters."""
match = re.match(r"<(a?):(\w+):(\d+)>", characters)
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index 988e0d49a..57b50c34f 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -1,7 +1,8 @@
import logging
+from contextlib import suppress
from datetime import datetime
-from discord import Colour, Message, NotFound, Object
+from discord import Colour, Forbidden, Message, NotFound, Object
from discord.ext import tasks
from discord.ext.commands import Cog, Context, command
@@ -29,15 +30,16 @@ your information removed here as well.
Feel free to review them at any point!
Additionally, if you'd like to receive notifications for the announcements we post in <#{Channels.announcements}> \
-from time to time, you can send `!subscribe` to <#{Channels.bot}> at any time to assign yourself the \
+from time to time, you can send `!subscribe` to <#{Channels.bot_commands}> at any time to assign yourself the \
**Announcements** role. We'll mention this role every time we make an announcement.
-If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to <#{Channels.bot}>.
+If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \
+<#{Channels.bot_commands}>.
"""
PERIODIC_PING = (
f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`."
- f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel."
+ f" If you encounter any problems during the verification process, ping the <@&{Roles.admins}> role in this channel."
)
BOT_MESSAGE_DELETE_DELAY = 10
@@ -92,19 +94,21 @@ class Verification(Cog):
ping_everyone=Filter.ping_everyone,
)
- ctx = await self.bot.get_context(message) # type: Context
-
+ ctx: Context = await self.bot.get_context(message)
if ctx.command is not None and ctx.command.name == "accept":
- return # They used the accept command
+ return
- for role in ctx.author.roles:
- if role.id == Roles.verified:
- log.warning(f"{ctx.author} posted '{ctx.message.content}' "
- "in the verification channel, but is already verified.")
- return # They're already verified
+ if any(r.id == Roles.verified for r in ctx.author.roles):
+ log.info(
+ f"{ctx.author} posted '{ctx.message.content}' "
+ "in the verification channel, but is already verified."
+ )
+ return
- log.debug(f"{ctx.author} posted '{ctx.message.content}' in the verification "
- "channel. We are providing instructions how to verify.")
+ log.debug(
+ f"{ctx.author} posted '{ctx.message.content}' in the verification "
+ "channel. We are providing instructions how to verify."
+ )
await ctx.send(
f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, "
f"and gain access to the rest of the server.",
@@ -112,11 +116,8 @@ class Verification(Cog):
)
log.trace(f"Deleting the message posted by {ctx.author}")
-
- try:
+ with suppress(NotFound):
await ctx.message.delete()
- except NotFound:
- log.trace("No message found, it must have been deleted by another bot.")
@command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True)
@without_role(Roles.verified)
@@ -127,20 +128,16 @@ class Verification(Cog):
await ctx.author.add_roles(Object(Roles.verified), reason="Accepted the rules")
try:
await ctx.author.send(WELCOME_MESSAGE)
- except Exception:
- # Catch the exception, in case they have DMs off or something
- log.exception(f"Unable to send welcome message to user {ctx.author}.")
-
- log.trace(f"Deleting the message posted by {ctx.author}.")
-
- try:
- self.mod_log.ignore(Event.message_delete, ctx.message.id)
- await ctx.message.delete()
- except NotFound:
- log.trace("No message found, it must have been deleted by another bot.")
+ except Forbidden:
+ log.info(f"Sending welcome message failed for {ctx.author}.")
+ finally:
+ log.trace(f"Deleting accept message by {ctx.author}.")
+ with suppress(NotFound):
+ self.mod_log.ignore(Event.message_delete, ctx.message.id)
+ await ctx.message.delete()
@command(name='subscribe')
- @in_channel(Channels.bot)
+ @in_channel(Channels.bot_commands)
async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Subscribe to announcement notifications by assigning yourself the role."""
has_role = False
@@ -164,7 +161,7 @@ class Verification(Cog):
)
@command(name='unsubscribe')
- @in_channel(Channels.bot)
+ @in_channel(Channels.bot_commands)
async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Unsubscribe from announcement notifications by removing the role from yourself."""
has_role = False
@@ -223,7 +220,7 @@ class Verification(Cog):
@periodic_ping.before_loop
async def before_ping(self) -> None:
"""Only start the loop when the bot is ready."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
def cog_unload(self) -> None:
"""Cancel the periodic ping task when the cog is unloaded."""
diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py
index eb787b083..3667a80e8 100644
--- a/bot/cogs/watchchannels/watchchannel.py
+++ b/bot/cogs/watchchannels/watchchannel.py
@@ -91,7 +91,7 @@ class WatchChannel(metaclass=CogABCMeta):
async def start_watchchannel(self) -> None:
"""Starts the watch channel by getting the channel, webhook, and user cache ready."""
- await self.bot.wait_until_ready()
+ await self.bot.wait_until_guild_available()
try:
self.channel = await self.bot.fetch_channel(self.destination)
diff --git a/bot/constants.py b/bot/constants.py
index fe8e57322..14f8dc094 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -186,6 +186,11 @@ class YAMLGetter(type):
def __getitem__(cls, name):
return cls.__getattr__(name)
+ def __iter__(cls):
+ """Return generator of key: value pairs of current constants class' config values."""
+ for name in cls.__annotations__:
+ yield name, getattr(cls, name)
+
# Dataclasses
class Bot(metaclass=YAMLGetter):
@@ -193,7 +198,7 @@ class Bot(metaclass=YAMLGetter):
prefix: str
token: str
-
+ sentry_dsn: str
class Filter(metaclass=YAMLGetter):
section = "filter"
@@ -263,6 +268,7 @@ class Emojis(metaclass=YAMLGetter):
new: str
pencil: str
cross_mark: str
+ check_mark: str
ducky_yellow: int
ducky_blurple: int
@@ -357,16 +363,16 @@ class Channels(metaclass=YAMLGetter):
section = "guild"
subsection = "channels"
- admins: int
admin_spam: int
+ admins: int
announcements: int
attachment_log: int
big_brother_logs: int
- bot: int
- checkpoint_test: int
+ bot_commands: int
defcon: int
- devlog: int
- devtest: int
+ dev_contrib: int
+ dev_core: int
+ dev_log: int
esoteric: int
help_0: int
help_1: int
@@ -379,19 +385,19 @@ class Channels(metaclass=YAMLGetter):
helpers: int
message_log: int
meta: int
+ mod_alerts: int
+ mod_log: int
mod_spam: int
mods: int
- mod_alerts: int
- modlog: int
off_topic_0: int
off_topic_1: int
off_topic_2: int
organisation: int
- python: int
+ python_discussion: int
reddit: int
talent_pool: int
- userlog: int
- user_event_a: int
+ user_event_announcements: int
+ user_log: int
verification: int
voice_log: int
@@ -404,25 +410,25 @@ class Webhooks(metaclass=YAMLGetter):
big_brother: int
reddit: int
duck_pond: int
+ dev_log: int
class Roles(metaclass=YAMLGetter):
section = "guild"
subsection = "roles"
- admin: int
+ admins: int
announcements: int
- champion: int
- contributor: int
- core_developer: int
+ contributors: int
+ core_developers: int
helpers: int
- jammer: int
- moderator: int
+ jammers: int
+ moderators: int
muted: int
- owner: int
+ owners: int
partners: int
- rockstars: int
- team_leader: int
+ python_community: int
+ team_leaders: int
verified: int # This is the Developers role on PyDis, here named verified for readability reasons.
@@ -430,9 +436,12 @@ class Guild(metaclass=YAMLGetter):
section = "guild"
id: int
- ignored: List[int]
+ moderation_channels: List[int]
+ moderation_roles: List[int]
+ modlog_blacklist: List[int]
+ reminder_whitelist: List[int]
staff_channels: List[int]
-
+ staff_roles: List[int]
class Keys(metaclass=YAMLGetter):
section = "keys"
@@ -537,6 +546,13 @@ class RedirectOutput(metaclass=YAMLGetter):
delete_delay: int
+class Sync(metaclass=YAMLGetter):
+ section = 'sync'
+
+ confirm_timeout: int
+ max_diff: int
+
+
class Event(Enum):
"""
Event names. This does not include every event (for example, raw
@@ -571,14 +587,14 @@ BOT_DIR = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir))
# Default role combinations
-MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner
-STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner
+MODERATION_ROLES = Guild.moderation_roles
+STAFF_ROLES = Guild.staff_roles
# Roles combinations
STAFF_CHANNELS = Guild.staff_channels
# Default Channel combinations
-MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam
+MODERATION_CHANNELS = Guild.moderation_channels
# Bot replies
diff --git a/bot/pagination.py b/bot/pagination.py
index e82763912..90c8f849c 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -1,8 +1,9 @@
import asyncio
import logging
-from typing import Iterable, List, Optional, Tuple
+import typing as t
+from contextlib import suppress
-from discord import Embed, Member, Message, Reaction
+import discord
from discord.abc import User
from discord.ext.commands import Context, Paginator
@@ -14,7 +15,7 @@ RIGHT_EMOJI = "\u27A1" # [:arrow_right:]
LAST_EMOJI = "\u23ED" # [:track_next:]
DELETE_EMOJI = constants.Emojis.trashcan # [:trashcan:]
-PAGINATION_EMOJI = [FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI]
+PAGINATION_EMOJI = (FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI)
log = logging.getLogger(__name__)
@@ -89,12 +90,12 @@ class LinePaginator(Paginator):
@classmethod
async def paginate(
cls,
- lines: Iterable[str],
+ lines: t.List[str],
ctx: Context,
- embed: Embed,
+ embed: discord.Embed,
prefix: str = "",
suffix: str = "",
- max_lines: Optional[int] = None,
+ max_lines: t.Optional[int] = None,
max_size: int = 500,
empty: bool = True,
restrict_to_user: User = None,
@@ -102,7 +103,7 @@ class LinePaginator(Paginator):
footer_text: str = None,
url: str = None,
exception_on_empty_embed: bool = False
- ) -> Optional[Message]:
+ ) -> t.Optional[discord.Message]:
"""
Use a paginator and set of reactions to provide pagination over a set of lines.
@@ -114,11 +115,11 @@ class LinePaginator(Paginator):
Pagination will also be removed automatically if no reaction is added for five minutes (300 seconds).
Example:
- >>> embed = Embed()
+ >>> embed = discord.Embed()
>>> embed.set_author(name="Some Operation", url=url, icon_url=icon)
- >>> await LinePaginator.paginate((line for line in lines), ctx, embed)
+ >>> await LinePaginator.paginate([line for line in lines], ctx, embed)
"""
- def event_check(reaction_: Reaction, user_: Member) -> bool:
+ def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool:
"""Make sure that this reaction is what we want to operate on."""
no_restrictions = (
# Pagination is not restricted
@@ -281,8 +282,9 @@ class LinePaginator(Paginator):
await message.edit(embed=embed)
- log.debug("Ending pagination and removing all reactions...")
- await message.clear_reactions()
+ log.debug("Ending pagination and clearing reactions.")
+ with suppress(discord.NotFound):
+ await message.clear_reactions()
class ImagePaginator(Paginator):
@@ -299,6 +301,7 @@ class ImagePaginator(Paginator):
self._current_page = [prefix]
self.images = []
self._pages = []
+ self._count = 0
def add_line(self, line: str = '', *, empty: bool = False) -> None:
"""Adds a line to each page."""
@@ -316,13 +319,13 @@ class ImagePaginator(Paginator):
@classmethod
async def paginate(
cls,
- pages: List[Tuple[str, str]],
- ctx: Context, embed: Embed,
+ pages: t.List[t.Tuple[str, str]],
+ ctx: Context, embed: discord.Embed,
prefix: str = "",
suffix: str = "",
timeout: int = 300,
exception_on_empty_embed: bool = False
- ) -> Optional[Message]:
+ ) -> t.Optional[discord.Message]:
"""
Use a paginator and set of reactions to provide pagination over a set of title/image pairs.
@@ -334,11 +337,11 @@ class ImagePaginator(Paginator):
Note: Pagination will be removed automatically if no reaction is added for five minutes (300 seconds).
Example:
- >>> embed = Embed()
+ >>> embed = discord.Embed()
>>> embed.set_author(name="Some Operation", url=url, icon_url=icon)
>>> await ImagePaginator.paginate(pages, ctx, embed)
"""
- def check_event(reaction_: Reaction, member: Member) -> bool:
+ def check_event(reaction_: discord.Reaction, member: discord.Member) -> bool:
"""Checks each reaction added, if it matches our conditions pass the wait_for."""
return all((
# Reaction is on the same message sent
@@ -445,5 +448,6 @@ class ImagePaginator(Paginator):
await message.edit(embed=embed)
- log.debug("Ending pagination and removing all reactions...")
- await message.clear_reactions()
+ log.debug("Ending pagination and clearing reactions.")
+ with suppress(discord.NotFound):
+ await message.clear_reactions()
diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py
index 00bb2a949..8903c385c 100644
--- a/bot/rules/attachments.py
+++ b/bot/rules/attachments.py
@@ -19,7 +19,7 @@ async def apply(
if total_recent_attachments > config['max']:
return (
- f"sent {total_recent_attachments} attachments in {config['max']}s",
+ f"sent {total_recent_attachments} attachments in {config['interval']}s",
(last_message.author,),
relevant_messages
)
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 8184be824..3e4b15ce4 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -1,5 +1,5 @@
from abc import ABCMeta
-from typing import Any, Generator, Hashable, Iterable
+from typing import Any, Hashable
from discord.ext.commands import CogMeta
@@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict):
for k in list(self.keys()):
v = super(CaseInsensitiveDict, self).pop(k)
self.__setitem__(k, v)
-
-
-def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]:
- """
- Generator that allows you to iterate over any indexable collection in `size`-length chunks.
-
- Found: https://stackoverflow.com/a/312464/4022104
- """
- for i in range(0, len(iterable), size):
- yield iterable[i:i + size]
diff --git a/bot/utils/time.py b/bot/utils/time.py
index 7416f36e0..77060143c 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -114,30 +114,40 @@ def format_infraction(timestamp: str) -> str:
def format_infraction_with_duration(
- expiry: Optional[str],
+ date_to: Optional[str],
date_from: Optional[datetime.datetime] = None,
- max_units: int = 2
+ max_units: int = 2,
+ absolute: bool = True
) -> Optional[str]:
"""
- Format an infraction timestamp to a more readable ISO 8601 format WITH the duration.
+ Return `date_to` formatted as a readable ISO-8601 with the humanized duration since `date_from`.
- Returns a human-readable version of the duration between datetime.utcnow() and an expiry.
- Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it.
- `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
- By default, max_units is 2.
+ `date_from` must be an ISO-8601 formatted timestamp. The duration is calculated as from
+ `date_from` until `date_to` with a precision of seconds. If `date_from` is unspecified, the
+ current time is used.
+
+ `max_units` specifies the maximum number of units of time to include in the duration. For
+ example, a value of 1 may include days but not hours.
+
+ If `absolute` is True, the absolute value of the duration delta is used. This prevents negative
+ values in the case that `date_to` is in the past relative to `date_from`.
"""
- if not expiry:
+ if not date_to:
return None
+ date_to_formatted = format_infraction(date_to)
+
date_from = date_from or datetime.datetime.utcnow()
- date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0)
+ date_to = dateutil.parser.isoparse(date_to).replace(tzinfo=None, microsecond=0)
- expiry_formatted = format_infraction(expiry)
+ delta = relativedelta(date_to, date_from)
+ if absolute:
+ delta = abs(delta)
- duration = humanize_delta(relativedelta(date_to, date_from), max_units=max_units)
- duration_formatted = f" ({duration})" if duration else ''
+ duration = humanize_delta(delta, max_units=max_units)
+ duration_formatted = f" ({duration})" if duration else ""
- return f"{expiry_formatted}{duration_formatted}"
+ return f"{date_to_formatted}{duration_formatted}"
def until_expiration(
diff --git a/config-default.yml b/config-default.yml
index 3345e6f2a..ab237423f 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -1,6 +1,7 @@
bot:
prefix: "!"
token: !ENV "BOT_TOKEN"
+ sentry_dsn: !ENV "BOT_SENTRY_DSN"
cooldowns:
# Per channel, per tag.
@@ -34,6 +35,7 @@ style:
pencil: "\u270F"
new: "\U0001F195"
cross_mark: "\u274C"
+ check_mark: "\u2705"
ducky_yellow: &DUCKY_YELLOW 574951975574175744
ducky_blurple: &DUCKY_BLURPLE 574951975310065675
@@ -109,74 +111,135 @@ guild:
id: 267624335836053506
categories:
- python_help: 356013061213126657
+ python_help: 356013061213126657
channels:
- admins: &ADMINS 365960823622991872
- admin_spam: &ADMIN_SPAM 563594791770914816
- admins_voice: &ADMINS_VOICE 500734494840717332
- announcements: 354619224620138496
- attachment_log: &ATTCH_LOG 649243850006855680
- big_brother_logs: &BBLOGS 468507907357409333
- bot: 267659945086812160
- checkpoint_test: 422077681434099723
- defcon: &DEFCON 464469101889454091
- devlog: &DEVLOG 622895325144940554
- devtest: &DEVTEST 414574275865870337
- esoteric: 470884583684964352
- help_0: 303906576991780866
- help_1: 303906556754395136
- help_2: 303906514266226689
- help_3: 439702951246692352
- help_4: 451312046647148554
- help_5: 454941769734422538
- help_6: 587375753306570782
- help_7: 587375768556797982
- helpers: &HELPERS 385474242440986624
- message_log: &MESSAGE_LOG 467752170159079424
- meta: 429409067623251969
- mod_spam: &MOD_SPAM 620607373828030464
- mods: &MODS 305126844661760000
- mod_alerts: 473092532147060736
- modlog: &MODLOG 282638479504965634
- off_topic_0: 291284109232308226
- off_topic_1: 463035241142026251
- off_topic_2: 463035268514185226
- organisation: &ORGANISATION 551789653284356126
- python: 267624335836053506
- reddit: 458224812528238616
- staff_lounge: &STAFF_LOUNGE 464905259261755392
- staff_voice: &STAFF_VOICE 412375055910043655
- talent_pool: &TALENT_POOL 534321732593647616
- userlog: 528976905546760203
- user_event_a: &USER_EVENT_A 592000283102674944
- verification: 352442727016693763
- voice_log: 640292421988646961
-
- staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON]
- ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG, *ADMINS_VOICE, *STAFF_VOICE, *ATTCH_LOG]
+ announcements: 354619224620138496
+ user_event_announcements: &USER_EVENT_A 592000283102674944
+
+ # Development
+ dev_contrib: &DEV_CONTRIB 635950537262759947
+ dev_core: &DEV_CORE 411200599653351425
+ dev_log: &DEV_LOG 622895325144940554
+
+ # Discussion
+ meta: 429409067623251969
+ python_discussion: 267624335836053506
+
+ # Logs
+ attachment_log: &ATTACH_LOG 649243850006855680
+ message_log: &MESSAGE_LOG 467752170159079424
+ mod_log: &MOD_LOG 282638479504965634
+ user_log: 528976905546760203
+ voice_log: 640292421988646961
+
+ # Off-topic
+ off_topic_0: 291284109232308226
+ off_topic_1: 463035241142026251
+ off_topic_2: 463035268514185226
+
+ # Python Help
+ help_0: 303906576991780866
+ help_1: 303906556754395136
+ help_2: 303906514266226689
+ help_3: 439702951246692352
+ help_4: 451312046647148554
+ help_5: 454941769734422538
+ help_6: 587375753306570782
+ help_7: 587375768556797982
+
+ # Special
+ bot_commands: &BOT_CMD 267659945086812160
+ esoteric: 470884583684964352
+ reddit: 458224812528238616
+ verification: 352442727016693763
+
+ # Staff
+ admins: &ADMINS 365960823622991872
+ admin_spam: &ADMIN_SPAM 563594791770914816
+ defcon: &DEFCON 464469101889454091
+ helpers: &HELPERS 385474242440986624
+ mods: &MODS 305126844661760000
+ mod_alerts: &MOD_ALERTS 473092532147060736
+ mod_spam: &MOD_SPAM 620607373828030464
+ organisation: &ORGANISATION 551789653284356126
+ staff_lounge: &STAFF_LOUNGE 464905259261755392
+
+ # Voice
+ admins_voice: &ADMINS_VOICE 500734494840717332
+ staff_voice: &STAFF_VOICE 412375055910043655
+
+ # Watch
+ big_brother_logs: &BB_LOGS 468507907357409333
+ talent_pool: &TALENT_POOL 534321732593647616
+
+ staff_channels:
+ - *ADMINS
+ - *ADMIN_SPAM
+ - *DEFCON
+ - *HELPERS
+ - *MODS
+ - *MOD_SPAM
+ - *ORGANISATION
+
+ moderation_channels:
+ - *ADMINS
+ - *ADMIN_SPAM
+ - *MOD_ALERTS
+ - *MODS
+ - *MOD_SPAM
+
+ # Modlog cog ignores events which occur in these channels
+ modlog_blacklist:
+ - *ADMINS
+ - *ADMINS_VOICE
+ - *ATTACH_LOG
+ - *MESSAGE_LOG
+ - *MOD_LOG
+ - *STAFF_VOICE
+
+ reminder_whitelist:
+ - *BOT_CMD
+ - *DEV_CONTRIB
roles:
- admin: &ADMIN_ROLE 267628507062992896
- announcements: 463658397560995840
- champion: 430492892331769857
- contributor: 295488872404484098
- core_developer: 587606783669829632
- helpers: 267630620367257601
- jammer: 591786436651646989
- moderator: &MOD_ROLE 267629731250176001
- muted: &MUTED_ROLE 277914926603829249
- owner: &OWNER_ROLE 267627879762755584
- partners: 323426753857191936
- rockstars: &ROCKSTARS_ROLE 458226413825294336
- team_leader: 501324292341104650
- verified: 352427296948486144
+ announcements: 463658397560995840
+ contributors: 295488872404484098
+ muted: &MUTED_ROLE 277914926603829249
+ partners: 323426753857191936
+ python_community: &PY_COMMUNITY_ROLE 458226413825294336
+
+ # This is the Developers role on PyDis, here named verified for readability reasons
+ verified: 352427296948486144
+
+ # Staff
+ admins: &ADMINS_ROLE 267628507062992896
+ core_developers: 587606783669829632
+ helpers: &HELPERS_ROLE 267630620367257601
+ moderators: &MODS_ROLE 267629731250176001
+ owners: &OWNERS_ROLE 267627879762755584
+
+ # Code Jam
+ jammers: 591786436651646989
+ team_leaders: 501324292341104650
+
+ moderation_roles:
+ - *OWNERS_ROLE
+ - *ADMINS_ROLE
+ - *MODS_ROLE
+
+ staff_roles:
+ - *OWNERS_ROLE
+ - *ADMINS_ROLE
+ - *MODS_ROLE
+ - *HELPERS_ROLE
webhooks:
- talent_pool: 569145364800602132
- big_brother: 569133704568373283
- reddit: 635408384794951680
- duck_pond: 637821475327311927
+ talent_pool: 569145364800602132
+ big_brother: 569133704568373283
+ reddit: 635408384794951680
+ duck_pond: 637821475327311927
+ dev_log: 680501655111729222
filter:
@@ -216,6 +279,7 @@ filter:
- 438622377094414346 # Pyglet
- 524691714909274162 # Panda3D
- 336642139381301249 # discord.py
+ - 405403391410438165 # Sentdex
domain_blacklist:
- pornhub.com
@@ -253,20 +317,19 @@ filter:
# Censor doesn't apply to these
channel_whitelist:
- *ADMINS
- - *MODLOG
+ - *MOD_LOG
- *MESSAGE_LOG
- - *DEVLOG
- - *BBLOGS
+ - *DEV_LOG
+ - *BB_LOGS
- *STAFF_LOUNGE
- - *DEVTEST
- *TALENT_POOL
- *USER_EVENT_A
role_whitelist:
- - *ADMIN_ROLE
- - *MOD_ROLE
- - *OWNER_ROLE
- - *ROCKSTARS_ROLE
+ - *ADMINS_ROLE
+ - *MODS_ROLE
+ - *OWNERS_ROLE
+ - *PY_COMMUNITY_ROLE
keys:
@@ -428,9 +491,26 @@ redirect_output:
delete_invocation: true
delete_delay: 15
+sync:
+ confirm_timeout: 300
+ max_diff: 10
+
duck_pond:
threshold: 5
- custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA]
+ custom_emojis:
+ - *DUCKY_YELLOW
+ - *DUCKY_BLURPLE
+ - *DUCKY_CAMO
+ - *DUCKY_DEVIL
+ - *DUCKY_NINJA
+ - *DUCKY_REGAL
+ - *DUCKY_TUBE
+ - *DUCKY_HUNT
+ - *DUCKY_WIZARD
+ - *DUCKY_PARTY
+ - *DUCKY_ANGEL
+ - *DUCKY_MAUL
+ - *DUCKY_SANTA
config:
required_keys: ['bot.token']
diff --git a/docker-compose.yml b/docker-compose.yml
index 7281c7953..11deceae8 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -23,6 +23,7 @@ services:
- staff.web
ports:
- "127.0.0.1:8000:8000"
+ tty: true
depends_on:
- postgres
environment:
@@ -37,6 +38,7 @@ services:
volumes:
- ./logs:/bot/logs
- .:/bot:ro
+ tty: true
depends_on:
- web
environment:
diff --git a/tests/base.py b/tests/base.py
index 029a249ed..88693f382 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -1,6 +1,12 @@
import logging
import unittest
from contextlib import contextmanager
+from typing import Dict
+
+import discord
+from discord.ext import commands
+
+from tests import helpers
class _CaptureLogHandler(logging.Handler):
@@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase):
standard_message = self._truncateMessage(base_message, record_message)
msg = self._formatMessage(msg, standard_message)
self.fail(msg)
+
+
+class CommandTestCase(unittest.TestCase):
+ """TestCase with additional assertions that are useful for testing Discord commands."""
+
+ @helpers.async_test
+ async def assertHasPermissionsCheck(
+ self,
+ cmd: commands.Command,
+ permissions: Dict[str, bool],
+ ) -> None:
+ """
+ Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`.
+
+ Every permission in `permissions` is expected to be reported as missing. In other words, do
+ not include permissions which should not raise an exception along with those which should.
+ """
+ # Invert permission values because it's more intuitive to pass to this assertion the same
+ # permissions as those given to the check decorator.
+ permissions = {k: not v for k, v in permissions.items()}
+
+ ctx = helpers.MockContext()
+ ctx.channel.permissions_for.return_value = discord.Permissions(**permissions)
+
+ with self.assertRaises(commands.MissingPermissions) as cm:
+ await cmd.can_run(ctx)
+
+ self.assertCountEqual(permissions.keys(), cm.exception.missing_perms)
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
new file mode 100644
index 000000000..c2e143865
--- /dev/null
+++ b/tests/bot/cogs/sync/test_base.py
@@ -0,0 +1,412 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs.sync.syncers import Syncer, _Diff
+from tests import helpers
+
+
+class TestSyncer(Syncer):
+ """Syncer subclass with mocks for abstract methods for testing purposes."""
+
+ name = "test"
+ _get_diff = helpers.AsyncMock()
+ _sync = helpers.AsyncMock()
+
+
+class SyncerBaseTests(unittest.TestCase):
+ """Tests for the syncer base class."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ def test_instantiation_fails_without_abstract_methods(self):
+ """The class must have abstract methods implemented."""
+ with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"):
+ Syncer(self.bot)
+
+
+class SyncerSendPromptTests(unittest.TestCase):
+ """Tests for sending the sync confirmation prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+
+ def mock_get_channel(self):
+ """Fixture to return a mock channel and message for when `get_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ mock_channel.send.return_value = mock_message
+ self.bot.get_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ def mock_fetch_channel(self):
+ """Fixture to return a mock channel and message for when `fetch_channel` is used."""
+ self.bot.reset_mock()
+
+ mock_channel = helpers.MockTextChannel()
+ mock_message = helpers.MockMessage()
+
+ self.bot.get_channel.return_value = None
+ mock_channel.send.return_value = mock_message
+ self.bot.fetch_channel.return_value = mock_channel
+
+ return mock_channel, mock_message
+
+ @helpers.async_test
+ async def test_send_prompt_edits_and_returns_message(self):
+ """The given message should be edited to display the prompt and then should be returned."""
+ msg = helpers.MockMessage()
+ ret_val = await self.syncer._send_prompt(msg)
+
+ msg.edit.assert_called_once()
+ self.assertIn("content", msg.edit.call_args[1])
+ self.assertEqual(ret_val, msg)
+
+ @helpers.async_test
+ async def test_send_prompt_gets_dev_core_channel(self):
+ """The dev-core channel should be retrieved if an extant message isn't given."""
+ subtests = (
+ (self.bot.get_channel, self.mock_get_channel),
+ (self.bot.fetch_channel, self.mock_fetch_channel),
+ )
+
+ for method, mock_ in subtests:
+ with self.subTest(method=method, msg=mock_.__name__):
+ mock_()
+ await self.syncer._send_prompt()
+
+ method.assert_called_once_with(constants.Channels.dev_core)
+
+ @helpers.async_test
+ async def test_send_prompt_returns_None_if_channel_fetch_fails(self):
+ """None should be returned if there's an HTTPException when fetching the channel."""
+ self.bot.get_channel.return_value = None
+ self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!")
+
+ ret_val = await self.syncer._send_prompt()
+
+ self.assertIsNone(ret_val)
+
+ @helpers.async_test
+ async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
+ """A new message mentioning core devs should be sent and returned if message isn't given."""
+ for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
+ with self.subTest(msg=mock_.__name__):
+ mock_channel, mock_message = mock_()
+ ret_val = await self.syncer._send_prompt()
+
+ mock_channel.send.assert_called_once()
+ self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
+ self.assertEqual(ret_val, mock_message)
+
+ @helpers.async_test
+ async def test_send_prompt_adds_reactions(self):
+ """The message should have reactions for confirmation added."""
+ extant_message = helpers.MockMessage()
+ subtests = (
+ (extant_message, lambda: (None, extant_message)),
+ (None, self.mock_get_channel),
+ (None, self.mock_fetch_channel),
+ )
+
+ for message_arg, mock_ in subtests:
+ subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__
+
+ with self.subTest(msg=subtest_msg):
+ _, mock_message = mock_()
+ await self.syncer._send_prompt(message_arg)
+
+ calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS]
+ mock_message.add_reaction.assert_has_calls(calls)
+
+
+class SyncerConfirmationTests(unittest.TestCase):
+ """Tests for waiting for a sync confirmation reaction on the prompt."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = TestSyncer(self.bot)
+ self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers)
+
+ @staticmethod
+ def get_message_reaction(emoji):
+ """Fixture to return a mock message an reaction from the given `emoji`."""
+ message = helpers.MockMessage()
+ reaction = helpers.MockReaction(emoji=emoji, message=message)
+
+ return message, reaction
+
+ def test_reaction_check_for_valid_emoji_and_authors(self):
+ """Should return True if authors are identical or are a bot and a core dev, respectively."""
+ user_subtests = (
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMember(id=77),
+ "identical users",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "bot author and core-dev reactor",
+ ),
+ )
+
+ for emoji in self.syncer._REACTION_EMOJIS:
+ for author, user, msg in user_subtests:
+ with self.subTest(author=author, user=user, emoji=emoji, msg=msg):
+ message, reaction = self.get_message_reaction(emoji)
+ ret_val = self.syncer._reaction_check(author, message, reaction, user)
+
+ self.assertTrue(ret_val)
+
+ def test_reaction_check_for_invalid_reactions(self):
+ """Should return False for invalid reaction events."""
+ valid_emoji = self.syncer._REACTION_EMOJIS[0]
+ subtests = (
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43, roles=[self.core_dev_role]),
+ "users are not identical",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=43),
+ "reactor lacks the core-dev role",
+ ),
+ (
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ *self.get_message_reaction(valid_emoji),
+ helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]),
+ "reactor is a bot",
+ ),
+ (
+ helpers.MockMember(id=77),
+ helpers.MockMessage(id=95),
+ helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)),
+ helpers.MockMember(id=77),
+ "messages are not identical",
+ ),
+ (
+ helpers.MockMember(id=77),
+ *self.get_message_reaction("InVaLiD"),
+ helpers.MockMember(id=77),
+ "emoji is invalid",
+ ),
+ )
+
+ for *args, msg in subtests:
+ kwargs = dict(zip(("author", "message", "reaction", "user"), args))
+ with self.subTest(**kwargs, msg=msg):
+ ret_val = self.syncer._reaction_check(*args)
+ self.assertFalse(ret_val)
+
+ @helpers.async_test
+ async def test_wait_for_confirmation(self):
+ """The message should always be edited and only return True if the emoji is a check mark."""
+ subtests = (
+ (constants.Emojis.check_mark, True, None),
+ ("InVaLiD", False, None),
+ (None, False, TimeoutError),
+ )
+
+ for emoji, ret_val, side_effect in subtests:
+ for bot in (True, False):
+ with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot):
+ # Set up mocks
+ message = helpers.MockMessage()
+ member = helpers.MockMember(bot=bot)
+
+ self.bot.wait_for.reset_mock()
+ self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None)
+ self.bot.wait_for.side_effect = side_effect
+
+ # Call the function
+ actual_return = await self.syncer._wait_for_confirmation(member, message)
+
+ # Perform assertions
+ self.bot.wait_for.assert_called_once()
+ self.assertIn("reaction_add", self.bot.wait_for.call_args[0])
+
+ message.edit.assert_called_once()
+ kwargs = message.edit.call_args[1]
+ self.assertIn("content", kwargs)
+
+ # Core devs should only be mentioned if the author is a bot.
+ if bot:
+ self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+ else:
+ self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"])
+
+ self.assertIs(actual_return, ret_val)
+
+
+class SyncerSyncTests(unittest.TestCase):
+ """Tests for main function orchestrating the sync."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
+ self.syncer = TestSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_respects_confirmation_result(self):
+ """The sync should abort if confirmation fails and continue if confirmed."""
+ mock_message = helpers.MockMessage()
+ subtests = (
+ (True, mock_message),
+ (False, None),
+ )
+
+ for confirmed, message in subtests:
+ with self.subTest(confirmed=confirmed):
+ self.syncer._sync.reset_mock()
+ self.syncer._get_diff.reset_mock()
+
+ diff = _Diff({1, 2, 3}, {4, 5}, None)
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = helpers.AsyncMock(
+ return_value=(confirmed, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+
+ if confirmed:
+ self.syncer._sync.assert_called_once_with(diff)
+ else:
+ self.syncer._sync.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_diff_size(self):
+ """The diff size should be correctly calculated."""
+ subtests = (
+ (6, _Diff({1, 2}, {3, 4}, {5, 6})),
+ (5, _Diff({1, 2, 3}, None, {4, 5})),
+ (0, _Diff(None, None, None)),
+ (0, _Diff(set(), set(), set())),
+ )
+
+ for size, diff in subtests:
+ with self.subTest(size=size, diff=diff):
+ self.syncer._get_diff.reset_mock()
+ self.syncer._get_diff.return_value = diff
+ self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ self.syncer._get_diff.assert_called_once_with(guild)
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
+
+ @helpers.async_test
+ async def test_sync_message_edited(self):
+ """The message should be edited if one was sent, even if the sync has an API error."""
+ subtests = (
+ (None, None, False),
+ (helpers.MockMessage(), None, True),
+ (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True),
+ )
+
+ for message, side_effect, should_edit in subtests:
+ with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
+ self.syncer._sync.side_effect = side_effect
+ self.syncer._get_confirmation_result = helpers.AsyncMock(
+ return_value=(True, message)
+ )
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild)
+
+ if should_edit:
+ message.edit.assert_called_once()
+ self.assertIn("content", message.edit.call_args[1])
+
+ @helpers.async_test
+ async def test_sync_confirmation_context_redirect(self):
+ """If ctx is given, a new message should be sent and author should be ctx's author."""
+ mock_member = helpers.MockMember()
+ subtests = (
+ (None, self.bot.user, None),
+ (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()),
+ )
+
+ for ctx, author, message in subtests:
+ with self.subTest(ctx=ctx, author=author, message=message):
+ if ctx is not None:
+ ctx.send.return_value = message
+
+ self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+
+ guild = helpers.MockGuild()
+ await self.syncer.sync(guild, ctx)
+
+ if ctx is not None:
+ ctx.send.assert_called_once()
+
+ self.syncer._get_confirmation_result.assert_called_once()
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author)
+ self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ @helpers.async_test
+ async def test_confirmation_result_small_diff(self):
+ """Should always return True and the given message if the diff size is too small."""
+ author = helpers.MockMember()
+ expected_message = helpers.MockMessage()
+
+ for size in (3, 2):
+ with self.subTest(size=size):
+ self.syncer._send_prompt = helpers.AsyncMock()
+ self.syncer._wait_for_confirmation = helpers.AsyncMock()
+
+ coro = self.syncer._get_confirmation_result(size, author, expected_message)
+ result, actual_message = await coro
+
+ self.assertTrue(result)
+ self.assertEqual(actual_message, expected_message)
+ self.syncer._send_prompt.assert_not_called()
+ self.syncer._wait_for_confirmation.assert_not_called()
+
+ @mock.patch.object(constants.Sync, "max_diff", new=3)
+ @helpers.async_test
+ async def test_confirmation_result_large_diff(self):
+ """Should return True if confirmed and False if _send_prompt fails or aborted."""
+ author = helpers.MockMember()
+ mock_message = helpers.MockMessage()
+
+ subtests = (
+ (True, mock_message, True, "confirmed"),
+ (False, None, False, "_send_prompt failed"),
+ (False, mock_message, False, "aborted"),
+ )
+
+ for expected_result, expected_message, confirmed, msg in subtests:
+ with self.subTest(msg=msg):
+ self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message)
+ self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed)
+
+ coro = self.syncer._get_confirmation_result(4, author)
+ actual_result, actual_message = await coro
+
+ self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None
+ self.assertIs(actual_result, expected_result)
+ self.assertEqual(actual_message, expected_message)
+
+ if expected_message:
+ self.syncer._wait_for_confirmation.assert_called_once_with(
+ author, expected_message
+ )
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
new file mode 100644
index 000000000..98c9afc0d
--- /dev/null
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -0,0 +1,395 @@
+import unittest
+from unittest import mock
+
+import discord
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.cogs import sync
+from bot.cogs.sync.syncers import Syncer
+from tests import helpers
+from tests.base import CommandTestCase
+
+
+class MockSyncer(helpers.CustomMockMixin, mock.MagicMock):
+ """
+ A MagicMock subclass to mock Syncer objects.
+
+ Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer`
+ instances. For more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=Syncer, **kwargs)
+
+
+class SyncExtensionTests(unittest.TestCase):
+ """Tests for the sync extension."""
+
+ @staticmethod
+ def test_extension_setup():
+ """The Sync cog should be added."""
+ bot = helpers.MockBot()
+ sync.setup(bot)
+ bot.add_cog.assert_called_once()
+
+
+class SyncCogTestCase(unittest.TestCase):
+ """Base class for Sync cog tests. Sets up patches for syncers."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+
+ # These patch the type. When the type is called, a MockSyncer instanced is returned.
+ # MockSyncer is needed so that our custom AsyncMock is used.
+ # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.
+ self.role_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.RoleSyncer",
+ new=mock.MagicMock(return_value=MockSyncer())
+ )
+ self.user_syncer_patcher = mock.patch(
+ "bot.cogs.sync.syncers.UserSyncer",
+ new=mock.MagicMock(return_value=MockSyncer())
+ )
+ self.RoleSyncer = self.role_syncer_patcher.start()
+ self.UserSyncer = self.user_syncer_patcher.start()
+
+ self.cog = sync.Sync(self.bot)
+
+ def tearDown(self):
+ self.role_syncer_patcher.stop()
+ self.user_syncer_patcher.stop()
+
+ @staticmethod
+ def response_error(status: int) -> ResponseCodeError:
+ """Fixture to return a ResponseCodeError with the given status code."""
+ response = mock.MagicMock()
+ response.status = status
+
+ return ResponseCodeError(response)
+
+
+class SyncCogTests(SyncCogTestCase):
+ """Tests for the Sync cog."""
+
+ @mock.patch.object(sync.Sync, "sync_guild")
+ def test_sync_cog_init(self, sync_guild):
+ """Should instantiate syncers and run a sync for the guild."""
+ # Reset because a Sync cog was already instantiated in setUp.
+ self.RoleSyncer.reset_mock()
+ self.UserSyncer.reset_mock()
+ self.bot.loop.create_task.reset_mock()
+
+ mock_sync_guild_coro = mock.MagicMock()
+ sync_guild.return_value = mock_sync_guild_coro
+
+ sync.Sync(self.bot)
+
+ self.RoleSyncer.assert_called_once_with(self.bot)
+ self.UserSyncer.assert_called_once_with(self.bot)
+ sync_guild.assert_called_once_with()
+ self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
+
+ @helpers.async_test
+ async def test_sync_cog_sync_guild(self):
+ """Roles and users should be synced only if a guild is successfully retrieved."""
+ for guild in (helpers.MockGuild(), None):
+ with self.subTest(guild=guild):
+ self.bot.reset_mock()
+ self.cog.role_syncer.reset_mock()
+ self.cog.user_syncer.reset_mock()
+
+ self.bot.get_guild = mock.MagicMock(return_value=guild)
+
+ await self.cog.sync_guild()
+
+ self.bot.wait_until_guild_available.assert_called_once()
+ self.bot.get_guild.assert_called_once_with(constants.Guild.id)
+
+ if guild is None:
+ self.cog.role_syncer.sync.assert_not_called()
+ self.cog.user_syncer.sync.assert_not_called()
+ else:
+ self.cog.role_syncer.sync.assert_called_once_with(guild)
+ self.cog.user_syncer.sync.assert_called_once_with(guild)
+
+ async def patch_user_helper(self, side_effect: BaseException) -> None:
+ """Helper to set a side effect for bot.api_client.patch and then assert it is called."""
+ self.bot.api_client.patch.reset_mock(side_effect=True)
+ self.bot.api_client.patch.side_effect = side_effect
+
+ user_id, updated_information = 5, {"key": 123}
+ await self.cog.patch_user(user_id, updated_information)
+
+ self.bot.api_client.patch.assert_called_once_with(
+ f"bot/users/{user_id}",
+ json=updated_information,
+ )
+
+ @helpers.async_test
+ async def test_sync_cog_patch_user(self):
+ """A PATCH request should be sent and 404 errors ignored."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ await self.patch_user_helper(side_effect)
+
+ @helpers.async_test
+ async def test_sync_cog_patch_user_non_404(self):
+ """A PATCH request should be sent and the error raised if it's not a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.patch_user_helper(self.response_error(500))
+
+
+class SyncCogListenerTests(SyncCogTestCase):
+ """Tests for the listeners of the Sync cog."""
+
+ def setUp(self):
+ super().setUp()
+ self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user)
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_create(self):
+ """A POST request should be sent with the new role's data."""
+ self.assertTrue(self.cog.on_guild_role_create.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ role = helpers.MockRole(**role_data)
+ await self.cog.on_guild_role_create(role)
+
+ self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data)
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_delete(self):
+ """A DELETE request should be sent."""
+ self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__)
+
+ role = helpers.MockRole(id=99)
+ await self.cog.on_guild_role_delete(role)
+
+ self.bot.api_client.delete.assert_called_once_with("bot/roles/99")
+
+ @helpers.async_test
+ async def test_sync_cog_on_guild_role_update(self):
+ """A PUT request should be sent if the colour, name, permissions, or position changes."""
+ self.assertTrue(self.cog.on_guild_role_update.__cog_listener__)
+
+ role_data = {
+ "colour": 49,
+ "id": 777,
+ "name": "rolename",
+ "permissions": 8,
+ "position": 23,
+ }
+ subtests = (
+ (True, ("colour", "name", "permissions", "position")),
+ (False, ("hoist", "mentionable")),
+ )
+
+ for should_put, attributes in subtests:
+ for attribute in attributes:
+ with self.subTest(should_put=should_put, changed_attribute=attribute):
+ self.bot.api_client.put.reset_mock()
+
+ after_role_data = role_data.copy()
+ after_role_data[attribute] = 876
+
+ before_role = helpers.MockRole(**role_data)
+ after_role = helpers.MockRole(**after_role_data)
+
+ await self.cog.on_guild_role_update(before_role, after_role)
+
+ if should_put:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/roles/{after_role.id}",
+ json=after_role_data
+ )
+ else:
+ self.bot.api_client.put.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_remove(self):
+ """Member should patched to set in_guild as False."""
+ self.assertTrue(self.cog.on_member_remove.__cog_listener__)
+
+ member = helpers.MockMember()
+ await self.cog.on_member_remove(member)
+
+ self.cog.patch_user.assert_called_once_with(
+ member.id,
+ updated_information={"in_guild": False}
+ )
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_update_roles(self):
+ """Members should be patched if their roles have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ # Roles are intentionally unsorted.
+ before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)]
+ before_member = helpers.MockMember(roles=before_roles)
+ after_member = helpers.MockMember(roles=before_roles[1:])
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ data = {"roles": sorted(role.id for role in after_member.roles)}
+ self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data)
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_update_other(self):
+ """Members should not be patched if other attributes have changed."""
+ self.assertTrue(self.cog.on_member_update.__cog_listener__)
+
+ subtests = (
+ ("activities", discord.Game("Pong"), discord.Game("Frogger")),
+ ("nick", "old nick", "new nick"),
+ ("status", discord.Status.online, discord.Status.offline),
+ )
+
+ for attribute, old_value, new_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ before_member = helpers.MockMember(**{attribute: old_value})
+ after_member = helpers.MockMember(**{attribute: new_value})
+
+ await self.cog.on_member_update(before_member, after_member)
+
+ self.cog.patch_user.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_user_update(self):
+ """A user should be patched only if the name, discriminator, or avatar changes."""
+ self.assertTrue(self.cog.on_user_update.__cog_listener__)
+
+ before_data = {
+ "name": "old name",
+ "discriminator": "1234",
+ "avatar": "old avatar",
+ "bot": False,
+ }
+
+ subtests = (
+ (True, "name", "name", "new name", "new name"),
+ (True, "discriminator", "discriminator", "8765", 8765),
+ (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"),
+ (False, "bot", "bot", True, True),
+ )
+
+ for should_patch, attribute, api_field, value, api_value in subtests:
+ with self.subTest(attribute=attribute):
+ self.cog.patch_user.reset_mock()
+
+ after_data = before_data.copy()
+ after_data[attribute] = value
+ before_user = helpers.MockUser(**before_data)
+ after_user = helpers.MockUser(**after_data)
+
+ await self.cog.on_user_update(before_user, after_user)
+
+ if should_patch:
+ self.cog.patch_user.assert_called_once()
+
+ # Don't care if *all* keys are present; only the changed one is required
+ call_args = self.cog.patch_user.call_args
+ self.assertEqual(call_args[0][0], after_user.id)
+ self.assertIn("updated_information", call_args[1])
+
+ updated_information = call_args[1]["updated_information"]
+ self.assertIn(api_field, updated_information)
+ self.assertEqual(updated_information[api_field], api_value)
+ else:
+ self.cog.patch_user.assert_not_called()
+
+ async def on_member_join_helper(self, side_effect: Exception) -> dict:
+ """
+ Helper to set `side_effect` for on_member_join and assert a PUT request was sent.
+
+ The request data for the mock member is returned. All exceptions will be re-raised.
+ """
+ member = helpers.MockMember(
+ discriminator="1234",
+ roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)],
+ )
+
+ data = {
+ "avatar_hash": member.avatar,
+ "discriminator": int(member.discriminator),
+ "id": member.id,
+ "in_guild": True,
+ "name": member.name,
+ "roles": sorted(role.id for role in member.roles)
+ }
+
+ self.bot.api_client.put.reset_mock(side_effect=True)
+ self.bot.api_client.put.side_effect = side_effect
+
+ try:
+ await self.cog.on_member_join(member)
+ except Exception:
+ raise
+ finally:
+ self.bot.api_client.put.assert_called_once_with(
+ f"bot/users/{member.id}",
+ json=data
+ )
+
+ return data
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_join(self):
+ """Should PUT user's data or POST it if the user doesn't exist."""
+ for side_effect in (None, self.response_error(404)):
+ with self.subTest(side_effect=side_effect):
+ self.bot.api_client.post.reset_mock()
+ data = await self.on_member_join_helper(side_effect)
+
+ if side_effect:
+ self.bot.api_client.post.assert_called_once_with("bot/users", json=data)
+ else:
+ self.bot.api_client.post.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_cog_on_member_join_non_404(self):
+ """ResponseCodeError should be re-raised if status code isn't a 404."""
+ with self.assertRaises(ResponseCodeError):
+ await self.on_member_join_helper(self.response_error(500))
+
+ self.bot.api_client.post.assert_not_called()
+
+
+class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
+ """Tests for the commands in the Sync cog."""
+
+ @helpers.async_test
+ async def test_sync_roles_command(self):
+ """sync() should be called on the RoleSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_roles_command.callback(self.cog, ctx)
+
+ self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ @helpers.async_test
+ async def test_sync_users_command(self):
+ """sync() should be called on the UserSyncer."""
+ ctx = helpers.MockContext()
+ await self.cog.sync_users_command.callback(self.cog, ctx)
+
+ self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
+
+ def test_commands_require_admin(self):
+ """The sync commands should only run if the author has the administrator permission."""
+ cmds = (
+ self.cog.sync_group,
+ self.cog.sync_roles_command,
+ self.cog.sync_users_command,
+ )
+
+ for cmd in cmds:
+ with self.subTest(cmd=cmd):
+ self.assertHasPermissionsCheck(cmd, {"administrator": True})
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
index 27ae27639..14fb2577a 100644
--- a/tests/bot/cogs/sync/test_roles.py
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -1,126 +1,165 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import Role, get_roles_for_sync
-
-
-class GetRolesForSyncTests(unittest.TestCase):
- """Tests constructing the roles to synchronize with the site."""
-
- def test_get_roles_for_sync_empty_return_for_equal_roles(self):
- """No roles should be synced when no diff is found."""
- api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), set(), set())
- )
-
- def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self):
- """Roles to be synced are returned when non-ID attributes differ."""
- api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)}
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (set(), guild_roles, set())
- )
-
- def test_get_roles_only_returns_roles_that_require_update(self):
- """Roles that require an update should be returned as the second tuple element."""
- api_roles = {
- Role(id=41, name='old name', colour=33, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
- guild_roles = {
- Role(id=41, name='new name', colour=35, permissions=0x8, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_new_roles_in_first_tuple_element(self):
- """Newly created roles are returned as the first tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=2)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=2)},
- set(),
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_update_and_new_roles(self):
- """Newly created and updated roles should be returned together."""
- api_roles = {
- Role(id=41, name='old name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='new name', colour=40, permissions=0x16, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=53, name='other role', colour=55, permissions=0, position=3)},
- {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)},
- set(),
- )
- )
-
- def test_get_roles_returns_roles_to_delete(self):
- """Roles to be deleted should be returned as the third tuple element."""
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- set(),
- set(),
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
-
- def test_get_roles_returns_roles_to_delete_update_and_new_roles(self):
- """When roles were added, updated, and removed, all of them are returned properly."""
- api_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- Role(id=71, name='to update', colour=99, permissions=0x9, position=3),
- }
- guild_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=81, name='to create', colour=99, permissions=0x9, position=4),
- Role(id=71, name='updated', colour=101, permissions=0x5, position=3),
- }
-
- self.assertEqual(
- get_roles_for_sync(guild_roles, api_roles),
- (
- {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)},
- {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)},
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
- )
+import discord
+
+from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role
+from tests import helpers
+
+
+def fake_role(**kwargs):
+ """Fixture to return a dictionary representing a role with default values set."""
+ kwargs.setdefault("id", 9)
+ kwargs.setdefault("name", "fake role")
+ kwargs.setdefault("colour", 7)
+ kwargs.setdefault("permissions", 0)
+ kwargs.setdefault("position", 55)
+
+ return kwargs
+
+
+class RoleSyncerDiffTests(unittest.TestCase):
+ """Tests for determining differences between roles in the DB and roles in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*roles):
+ """Fixture to return a guild object with the given roles."""
+ guild = helpers.MockGuild()
+ guild.roles = []
+
+ for role in roles:
+ mock_role = helpers.MockRole(**role)
+ mock_role.colour = discord.Colour(role["colour"])
+ mock_role.permissions = discord.Permissions(role["permissions"])
+ guild.roles.append(mock_role)
+
+ return guild
+
+ @helpers.async_test
+ async def test_empty_diff_for_identical_roles(self):
+ """No differences should be found if the roles in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_updated_roles(self):
+ """Only updated roles should be added to the 'updated' set of the diff."""
+ updated_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()]
+ guild = self.get_guild(updated_role, fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_Role(**updated_role)}, set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_roles(self):
+ """Only new roles should be added to the 'created' set of the diff."""
+ new_role = fake_role(id=41, name="new")
+
+ self.bot.api_client.get.return_value = [fake_role()]
+ guild = self.get_guild(fake_role(), new_role)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new_role)}, set(), set())
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_deleted_roles(self):
+ """Only deleted roles should be added to the 'deleted' set of the diff."""
+ deleted_role = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [fake_role(), deleted_role]
+ guild = self.get_guild(fake_role())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), {_Role(**deleted_role)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_updated_and_deleted_roles(self):
+ """When roles are added, updated, and removed, all of them are returned properly."""
+ new = fake_role(id=41, name="new")
+ updated = fake_role(id=71, name="updated")
+ deleted = fake_role(id=61, name="deleted")
+
+ self.bot.api_client.get.return_value = [
+ fake_role(),
+ fake_role(id=71, name="updated name"),
+ deleted,
+ ]
+ guild = self.get_guild(fake_role(), new, updated)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)})
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class RoleSyncerSyncTests(unittest.TestCase):
+ """Tests for the API requests that sync roles."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = RoleSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_created_roles(self):
+ """Only POST requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(role_tuples, set(), set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/roles", json=role) for role in roles]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(roles))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_updated_roles(self):
+ """Only PUT requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), role_tuples, set())
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_deleted_roles(self):
+ """Only DELETE requests should be made with the correct payload."""
+ roles = [fake_role(id=111), fake_role(id=222)]
+
+ role_tuples = {_Role(**role) for role in roles}
+ diff = _Diff(set(), set(), role_tuples)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/roles/{role['id']}") for role in roles]
+ self.bot.api_client.delete.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.delete.call_count, len(roles))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.put.assert_not_called()
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index ccaf67490..421bf6bb6 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -1,84 +1,169 @@
import unittest
+from unittest import mock
-from bot.cogs.sync.syncers import User, get_users_for_sync
+from bot.cogs.sync.syncers import UserSyncer, _Diff, _User
+from tests import helpers
def fake_user(**kwargs):
- kwargs.setdefault('id', 43)
- kwargs.setdefault('name', 'bob the test man')
- kwargs.setdefault('discriminator', 1337)
- kwargs.setdefault('avatar_hash', None)
- kwargs.setdefault('roles', (666,))
- kwargs.setdefault('in_guild', True)
- return User(**kwargs)
-
-
-class GetUsersForSyncTests(unittest.TestCase):
- """Tests constructing the users to synchronize with the site."""
-
- def test_get_users_for_sync_returns_nothing_for_empty_params(self):
- """When no users are given, none are returned."""
- self.assertEqual(
- get_users_for_sync({}, {}),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_nothing_for_equal_users(self):
- """When no users are updated, none are returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
-
- def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self):
- """When a non-ID-field differs, the user to update is returned."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(name='new fancy name')}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(name='new fancy name')})
- )
-
- def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self):
- """When new users join the guild, they are returned as the first tuple element."""
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(), 63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, set())
- )
-
- def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self):
+ """Fixture to return a dictionary representing a user with default values set."""
+ kwargs.setdefault("id", 43)
+ kwargs.setdefault("name", "bob the test man")
+ kwargs.setdefault("discriminator", 1337)
+ kwargs.setdefault("avatar_hash", None)
+ kwargs.setdefault("roles", (666,))
+ kwargs.setdefault("in_guild", True)
+
+ return kwargs
+
+
+class UserSyncerDiffTests(unittest.TestCase):
+ """Tests for determining differences between users in the DB and users in the Guild cache."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @staticmethod
+ def get_guild(*members):
+ """Fixture to return a guild object with the given members."""
+ guild = helpers.MockGuild()
+ guild.members = []
+
+ for member in members:
+ member = member.copy()
+ member["avatar"] = member.pop("avatar_hash")
+ del member["in_guild"]
+
+ mock_member = helpers.MockMember(**member)
+ mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]]
+
+ guild.members.append(mock_member)
+
+ return guild
+
+ @helpers.async_test
+ async def test_empty_diff_for_no_users(self):
+ """When no users are given, an empty diff should be returned."""
+ guild = self.get_guild()
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_empty_diff_for_identical_users(self):
+ """No differences should be found if the users in the guild and DB are identical."""
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_updated_users(self):
+ """Only updated users should be added to the 'updated' set of the diff."""
+ updated_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()]
+ guild = self.get_guild(updated_user, fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**updated_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_users(self):
+ """Only new users should be added to the 'created' set of the diff."""
+ new_user = fake_user(id=99, name="new")
+
+ self.bot.api_client.get.return_value = [fake_user()]
+ guild = self.get_guild(fake_user(), new_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
- api_users = {43: fake_user(), 63: fake_user(id=63)}
- guild_users = {43: fake_user()}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), {fake_user(id=63, in_guild=False)})
- )
-
- def test_get_users_for_sync_updates_and_creates_users_as_needed(self):
- """When one user left and another one was updated, both are returned."""
- api_users = {43: fake_user()}
- guild_users = {63: fake_user(id=63)}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- ({fake_user(id=63)}, {fake_user(in_guild=False)})
- )
-
- def test_get_users_for_sync_does_not_duplicate_update_users(self):
- """When the API knows a user the guild doesn't, nothing is performed."""
- api_users = {43: fake_user(in_guild=False)}
- guild_users = {}
-
- self.assertEqual(
- get_users_for_sync(guild_users, api_users),
- (set(), set())
- )
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), {_User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_diff_for_new_updated_and_leaving_users(self):
+ """When users are added, updated, and removed, all of them are returned properly."""
+ new_user = fake_user(id=99, name="new")
+ updated_user = fake_user(id=55, name="updated")
+ leaving_user = fake_user(id=63, in_guild=False)
+
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)]
+ guild = self.get_guild(fake_user(), new_user, updated_user)
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+ @helpers.async_test
+ async def test_empty_diff_for_db_users_not_in_guild(self):
+ """When the DB knows a user the guild doesn't, no difference is found."""
+ self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
+ guild = self.get_guild(fake_user())
+
+ actual_diff = await self.syncer._get_diff(guild)
+ expected_diff = (set(), set(), None)
+
+ self.assertEqual(actual_diff, expected_diff)
+
+
+class UserSyncerSyncTests(unittest.TestCase):
+ """Tests for the API requests that sync users."""
+
+ def setUp(self):
+ self.bot = helpers.MockBot()
+ self.syncer = UserSyncer(self.bot)
+
+ @helpers.async_test
+ async def test_sync_created_users(self):
+ """Only POST requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(user_tuples, set(), None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call("bot/users", json=user) for user in users]
+ self.bot.api_client.post.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.post.call_count, len(users))
+
+ self.bot.api_client.put.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
+
+ @helpers.async_test
+ async def test_sync_updated_users(self):
+ """Only PUT requests should be made with the correct payload."""
+ users = [fake_user(id=111), fake_user(id=222)]
+
+ user_tuples = {_User(**user) for user in users}
+ diff = _Diff(set(), user_tuples, None)
+ await self.syncer._sync(diff)
+
+ calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users]
+ self.bot.api_client.put.assert_has_calls(calls, any_order=True)
+ self.assertEqual(self.bot.api_client.put.call_count, len(users))
+
+ self.bot.api_client.post.assert_not_called()
+ self.bot.api_client.delete.assert_not_called()
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index d07b2bce1..5b0a3b8c3 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase):
asyncio.run(self.cog.fetch_webhook())
- self.bot.wait_until_ready.assert_called_once()
+ self.bot.wait_until_guild_available.assert_called_once()
self.bot.fetch_webhook.assert_called_once_with(1)
self.assertEqual(self.cog.webhook, "dummy webhook")
@@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase):
with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher:
asyncio.run(self.cog.fetch_webhook())
- self.bot.wait_until_ready.assert_called_once()
+ self.bot.wait_until_guild_available.assert_called_once()
self.bot.fetch_webhook.assert_called_once_with(1)
self.assertEqual(len(log_watcher.records), 1)
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 4496a2ae0..8443cfe71 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator)
+ cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators)
def setUp(self):
"""Sets up fresh objects for each test."""
@@ -125,10 +125,10 @@ class InformationCogTests(unittest.TestCase):
)
],
members=[
- *(helpers.MockMember(status='online') for _ in range(2)),
- *(helpers.MockMember(status='idle') for _ in range(1)),
- *(helpers.MockMember(status='dnd') for _ in range(4)),
- *(helpers.MockMember(status='offline') for _ in range(3)),
+ *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
+ *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
+ *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
+ *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
],
member_count=1_234,
icon_url='a-lemon.jpg',
@@ -153,9 +153,9 @@ class InformationCogTests(unittest.TestCase):
**Counts**
Members: {self.ctx.guild.member_count:,}
Roles: {len(self.ctx.guild.roles)}
- Text: 1
- Voice: 1
- Channel categories: 1
+ Category channels: 1
+ Text channels: 1
+ Voice channels: 1
**Members**
{constants.Emojis.status_online} 2
@@ -521,7 +521,7 @@ class UserCommandTests(unittest.TestCase):
"""A regular user should not be able to use this command outside of bot-commands."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))
@@ -533,7 +533,7 @@ class UserCommandTests(unittest.TestCase):
def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50))
@@ -546,7 +546,7 @@ class UserCommandTests(unittest.TestCase):
def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50))
@@ -559,7 +559,7 @@ class UserCommandTests(unittest.TestCase):
def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index e69de29bb..36c986fe1 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -0,0 +1,76 @@
+import unittest
+from abc import ABCMeta, abstractmethod
+from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
+
+from tests.helpers import MockMessage
+
+
+class DisallowedCase(NamedTuple):
+ """Encapsulation for test cases expected to fail."""
+ recent_messages: List[MockMessage]
+ culprits: Iterable[str]
+ n_violations: int
+
+
+class RuleTest(unittest.TestCase, metaclass=ABCMeta):
+ """
+ Abstract class for antispam rule test cases.
+
+ Tests for specific rules should inherit from `RuleTest` and implement
+ `relevant_messages` and `get_report`. Each instance should also set the
+ `apply` and `config` attributes as necessary.
+
+ The execution of test cases can then be delegated to the `run_allowed`
+ and `run_disallowed` methods.
+ """
+
+ apply: Callable # The tested rule's apply function
+ config: Dict[str, int]
+
+ async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to pass."""
+ for recent_messages in cases:
+ last_message = recent_messages[0]
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ config=self.config,
+ ):
+ self.assertIsNone(
+ await self.apply(last_message, recent_messages, self.config)
+ )
+
+ async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to fail."""
+ for case in cases:
+ recent_messages, culprits, n_violations = case
+ last_message = recent_messages[0]
+ relevant_messages = self.relevant_messages(case)
+ desired_output = (
+ self.get_report(case),
+ culprits,
+ relevant_messages,
+ )
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ relevant_messages=relevant_messages,
+ n_violations=n_violations,
+ config=self.config,
+ ):
+ self.assertTupleEqual(
+ await self.apply(last_message, recent_messages, self.config),
+ desired_output,
+ )
+
+ @abstractmethod
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ """Give expected relevant messages for `case`."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_report(self, case: DisallowedCase) -> str:
+ """Give expected error report for `case`."""
+ raise NotImplementedError
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index d7187f315..e54b4b5b8 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -1,98 +1,71 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import attachments
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_attachments: int
-
-
-def msg(author: str, total_attachments: int) -> MockMessage:
+def make_msg(author: str, total_attachments: int) -> MockMessage:
"""Builds a message with `total_attachments` attachments."""
return MockMessage(author=author, attachments=list(range(total_attachments)))
-class AttachmentRuleTests(unittest.TestCase):
+class AttachmentRuleTests(RuleTest):
"""Tests applying the `attachments` antispam rule."""
def setUp(self):
- self.config = {"max": 5}
+ self.apply = attachments.apply
+ self.config = {"max": 5, "interval": 10}
@async_test
async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
- [msg("bob", 0), msg("bob", 0), msg("bob", 0)],
- [msg("bob", 2), msg("bob", 2)],
- [msg("bob", 2), msg("alice", 2), msg("bob", 2)],
+ [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)],
+ [make_msg("bob", 2), make_msg("bob", 2)],
+ [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await attachments.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
- Case(
- [msg("bob", 4), msg("bob", 0), msg("bob", 6)],
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
("bob",),
- 10
+ 10,
),
- Case(
- [msg("bob", 4), msg("alice", 6), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
("bob",),
- 6
+ 6,
),
- Case(
- [msg("alice", 6)],
+ DisallowedCase(
+ [make_msg("alice", 6)],
("alice",),
- 6
+ 6,
),
- (
- [msg("alice", 1) for _ in range(6)],
+ DisallowedCase(
+ [make_msg("alice", 1) for _ in range(6)],
("alice",),
- 6
+ 6,
),
)
- for recent_messages, culprit, total_attachments in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and len(msg.attachments) > 0
)
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- total_attachments=total_attachments,
- config=self.config
- ):
- desired_output = (
- f"sent {total_attachments} attachments in {self.config['max']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await attachments.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
new file mode 100644
index 000000000..72f0be0c7
--- /dev/null
+++ b/tests/bot/rules/test_burst.py
@@ -0,0 +1,56 @@
+from typing import Iterable
+
+from bot.rules import burst
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with author set to `author`.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstRuleTests(RuleTest):
+ """Tests the `burst` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("bob"), make_msg("bob")],
+ [make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
new file mode 100644
index 000000000..47367a5f8
--- /dev/null
+++ b/tests/bot/rules/test_burst_shared.py
@@ -0,0 +1,59 @@
+from typing import Iterable
+
+from bot.rules import burst_shared
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with the passed arg.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstSharedRuleTests(RuleTest):
+ """Tests the `burst_shared` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst_shared.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """
+ Cases that do not violate the rule.
+
+ There really isn't more to test here than a single case.
+ """
+ cases = (
+ [make_msg("spongebob"), make_msg("patrick")],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ {"bob"},
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ {"bob", "alice"},
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return case.recent_messages
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
new file mode 100644
index 000000000..7cc36f49e
--- /dev/null
+++ b/tests/bot/rules/test_chars.py
@@ -0,0 +1,66 @@
+from typing import Iterable
+
+from bot.rules import chars
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, n_chars: int) -> MockMessage:
+ """Build a message with arbitrary content of `n_chars` length."""
+ return MockMessage(author=author, content="A" * n_chars)
+
+
+class CharsRuleTests(RuleTest):
+ """Tests the `chars` antispam rule."""
+
+ def setUp(self):
+ self.apply = chars.apply
+ self.config = {
+ "max": 20, # Max allowed sum of chars per user
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of chars within limit."""
+ cases = (
+ [make_msg("bob", 0)],
+ [make_msg("bob", 20)],
+ [make_msg("bob", 15), make_msg("alice", 15)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the total amount of chars exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 21)],
+ ("bob",),
+ 21,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 15), make_msg("bob", 15)],
+ ("bob",),
+ 30,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
+ ("alice",),
+ 30,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
new file mode 100644
index 000000000..0239b0b00
--- /dev/null
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -0,0 +1,54 @@
+from typing import Iterable
+
+from bot.rules import discord_emojis
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
+
+
+def make_msg(author: str, n_emojis: int) -> MockMessage:
+ """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis."""
+ return MockMessage(author=author, content=discord_emoji * n_emojis)
+
+
+class DiscordEmojisRuleTests(RuleTest):
+ """Tests for the `discord_emojis` antispam rule."""
+
+ def setUp(self):
+ self.apply = discord_emojis.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of discord emojis within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of discord emojis."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
new file mode 100644
index 000000000..59e0fb6ef
--- /dev/null
+++ b/tests/bot/rules/test_duplicates.py
@@ -0,0 +1,66 @@
+from typing import Iterable
+
+from bot.rules import duplicates
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, content: str) -> MockMessage:
+ """Give a MockMessage instance with `author` and `content` attrs."""
+ return MockMessage(author=author, content=content)
+
+
+class DuplicatesRuleTests(RuleTest):
+ """Tests the `duplicates` antispam rule."""
+
+ def setUp(self):
+ self.apply = duplicates.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", "A"), make_msg("alice", "A")],
+ [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate
+ [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with too many duplicate messages from the same author."""
+ cases = (
+ DisallowedCase(
+ [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")],
+ ("alice",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 duplicate messages, but only 3 from bob
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 message from bob, but only 3 duplicates
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and msg.content == last_message.content
+ )
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index 02a5d5501..3c3f90e5f 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -1,26 +1,21 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import links
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_links: int
-
-
-def msg(author: str, total_links: int) -> MockMessage:
+def make_msg(author: str, total_links: int) -> MockMessage:
"""Makes a message with `total_links` links."""
content = " ".join(["https://pydis.com"] * total_links)
return MockMessage(author=author, content=content)
-class LinksTests(unittest.TestCase):
+class LinksTests(RuleTest):
"""Tests applying the `links` rule."""
def setUp(self):
+ self.apply = links.apply
self.config = {
"max": 2,
"interval": 10
@@ -30,68 +25,45 @@ class LinksTests(unittest.TestCase):
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await links.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
- Case(
- [msg("bob", 1), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 1), make_msg("bob", 2)],
("bob",),
3
),
- Case(
- [msg("alice", 1), msg("alice", 1), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
("alice",),
3
),
- Case(
- [msg("alice", 2), msg("bob", 3), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
("alice",),
3
)
)
- for recent_messages, culprit, total_links in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_links=total_links,
- config=self.config
- ):
- desired_output = (
- f"sent {total_links} links in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await links.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index ad49ead32..ebcdabac6 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -1,95 +1,67 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import mentions
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_mentions: int
-
-
-def msg(author: str, total_mentions: int) -> MockMessage:
+def make_msg(author: str, total_mentions: int) -> MockMessage:
"""Makes a message with `total_mentions` mentions."""
return MockMessage(author=author, mentions=list(range(total_mentions)))
-class TestMentions(unittest.TestCase):
+class TestMentions(RuleTest):
"""Tests applying the `mentions` antispam rule."""
def setUp(self):
+ self.apply = mentions.apply
self.config = {
"max": 2,
- "interval": 10
+ "interval": 10,
}
@async_test
async def test_mentions_within_limit(self):
"""Messages with an allowed amount of mentions."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 1), msg("alice", 2)]
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 1), make_msg("alice", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await mentions.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
- Case(
- [msg("bob", 3)],
+ DisallowedCase(
+ [make_msg("bob", 3)],
("bob",),
- 3
+ 3,
),
- Case(
- [msg("alice", 2), msg("alice", 0), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
("alice",),
- 3
+ 3,
),
- Case(
- [msg("bob", 2), msg("alice", 3), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
("bob",),
- 4
+ 4,
)
)
- for recent_messages, culprit, total_mentions in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_mentions=total_mentions,
- cofig=self.config
- ):
- desired_output = (
- f"sent {total_mentions} mentions in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await mentions.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
new file mode 100644
index 000000000..d61c4609d
--- /dev/null
+++ b/tests/bot/rules/test_newlines.py
@@ -0,0 +1,105 @@
+from typing import Iterable, List
+
+from bot.rules import newlines
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
+ """Init a MockMessage instance with `author` and content configured by `newline_groups".
+
+ Configure content by passing a list of ints, where each int `n` will generate
+ a separate group of `n` newlines.
+
+ Example:
+ newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n"
+ """
+ content = " ".join("\n" * n for n in newline_groups)
+ return MockMessage(author=author, content=content)
+
+
+class TotalNewlinesRuleTests(RuleTest):
+ """Tests the `newlines` antispam rule against allowed cases and total newline count violations."""
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {
+ "max": 5, # Max sum of newlines in relevant messages
+ "max_consecutive": 3, # Max newlines in one group, in one message
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", [])], # Single message with no newlines
+ [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages
+ [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author
+ [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_total(self):
+ """Cases which violate the rule by having too many newlines in total."""
+ cases = (
+ DisallowedCase( # Alice sends a total of 6 newlines (disallowed)
+ [make_msg("alice", [2, 2]), make_msg("alice", [2])],
+ ("alice",),
+ 6,
+ ),
+ DisallowedCase( # Here we test that only alice's newlines count in the sum
+ [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])],
+ ("alice",),
+ 7,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} newlines in {self.config['interval']}s"
+
+
+class GroupNewlinesRuleTests(RuleTest):
+ """
+ Tests the `newlines` antispam rule against max consecutive newline violations.
+
+ As these violations yield a different error report, they require a different
+ `get_report` implementation.
+ """
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
+
+ @async_test
+ async def test_disallows_messages_consecutive(self):
+ """Cases which violate the rule due to having too many consecutive newlines."""
+ cases = (
+ DisallowedCase( # Bob sends a group of newlines too large
+ [make_msg("bob", [4])],
+ ("bob",),
+ 4,
+ ),
+ DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed)
+ [make_msg("alice", [1]), make_msg("alice", [4])],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
new file mode 100644
index 000000000..b339cccf7
--- /dev/null
+++ b/tests/bot/rules/test_role_mentions.py
@@ -0,0 +1,57 @@
+from typing import Iterable
+
+from bot.rules import role_mentions
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, n_mentions: int) -> MockMessage:
+ """Build a MockMessage instance with `n_mentions` role mentions."""
+ return MockMessage(author=author, role_mentions=[None] * n_mentions)
+
+
+class RoleMentionsRuleTests(RuleTest):
+ """Tests for the `role_mentions` antispam rule."""
+
+ def setUp(self):
+ self.apply = role_mentions.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of role mentions within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of role mentions."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} role mentions in {self.config['interval']}s"
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index 5a88adc5c..bdfcc73e4 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -1,9 +1,7 @@
-import logging
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
from bot import api
-from tests.base import LoggingTestCase
from tests.helpers import async_test
@@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase):
self.assertEqual(error.response_text, "")
self.assertIs(error.response, self.error_api_response)
- def test_responde_code_error_string_representation_default_initialization(self):
+ def test_response_code_error_string_representation_default_initialization(self):
"""Test the string representation of `ResponseCodeError` initialized without text or json."""
error = api.ResponseCodeError(response=self.error_api_response)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ")
@@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase):
response_text=text_data
)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}")
-
-
-class LoggingHandlerTests(LoggingTestCase):
- """Tests the bot's API Log Handler."""
-
- @classmethod
- def setUpClass(cls):
- cls.debug_log_record = logging.LogRecord(
- name='my.logger', level=logging.DEBUG,
- pathname='my/logger.py', lineno=666,
- msg="Lemon wins", args=(),
- exc_info=None
- )
-
- cls.trace_log_record = logging.LogRecord(
- name='my.logger', level=logging.TRACE,
- pathname='my/logger.py', lineno=666,
- msg="This will not be logged", args=(),
- exc_info=None
- )
-
- def setUp(self):
- self.log_handler = api.APILoggingHandler(None)
-
- def test_emit_appends_to_queue_with_stopped_event_loop(self):
- """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running."""
- with patch("bot.api.APILoggingHandler.ship_off") as ship_off:
- # Patch `ship_off` to ease testing against the return value of this coroutine.
- ship_off.return_value = 42
- self.log_handler.emit(self.debug_log_record)
-
- self.assertListEqual(self.log_handler.queue, [42])
-
- def test_emit_ignores_less_than_debug(self):
- """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG."""
- self.log_handler.emit(self.trace_log_record)
- self.assertListEqual(self.log_handler.queue, [])
-
- def test_schedule_queued_tasks_for_empty_queue(self):
- """`APILoggingHandler` should not schedule anything when the queue is empty."""
- with self.assertNotLogs(level=logging.DEBUG):
- self.log_handler.schedule_queued_tasks()
-
- def test_schedule_queued_tasks_for_nonempty_queue(self):
- """`APILoggingHandler` should schedule logs when the queue is not empty."""
- log = logging.getLogger("bot.api")
-
- with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task:
- self.log_handler.queue = [555]
- self.log_handler.schedule_queued_tasks()
- self.assertListEqual(self.log_handler.queue, [])
- create_task.assert_called_once_with(555)
-
- [record] = logs.records
- self.assertEqual(record.message, "Scheduled 1 pending logging tasks.")
- self.assertEqual(record.levelno, logging.DEBUG)
- self.assertEqual(record.name, 'bot.api')
- self.assertIn('via_handler', record.__dict__)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
index 58ae2a81a..d7bcc3ba6 100644
--- a/tests/bot/test_utils.py
+++ b/tests/bot/test_utils.py
@@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):
instance = utils.CaseInsensitiveDict()
instance.update({'FOO': 'bar'})
self.assertEqual(instance['foo'], 'bar')
-
-
-class ChunkTests(unittest.TestCase):
- """Tests the `chunk` method."""
-
- def test_empty_chunking(self):
- """Tests chunking on an empty iterable."""
- generator = utils.chunks(iterable=[], size=5)
- self.assertEqual(list(generator), [])
-
- def test_list_chunking(self):
- """Tests chunking a non-empty list."""
- iterable = [1, 2, 3, 4, 5]
- generator = utils.chunks(iterable=iterable, size=2)
- self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])
diff --git a/tests/helpers.py b/tests/helpers.py
index 6aee8623f..6f50f6ae3 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -12,6 +12,7 @@ from typing import Any, Iterable, Optional
import discord
from discord.ext.commands import Context
+from bot.api import APIClient
from bot.bot import Bot
@@ -281,9 +282,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
information, see the `MockGuild` docstring.
"""
def __init__(self, **kwargs) -> None:
- default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1}
+ default_kwargs = {
+ 'id': next(self.discord_id),
+ 'name': 'role',
+ 'position': 1,
+ 'colour': discord.Colour(0xdeadbf),
+ 'permissions': discord.Permissions(),
+ }
super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs))
+ if isinstance(self.colour, int):
+ self.colour = discord.Colour(self.colour)
+
+ if isinstance(self.permissions, int):
+ self.permissions = discord.Permissions(self.permissions)
+
if 'mention' not in kwargs:
self.mention = f'&{self.name}'
@@ -336,6 +349,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
self.mention = f"@{self.name}"
+class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock APIClient objects.
+
+ Instances of this class will follow the specifications of `bot.api.APIClient` instances.
+ For more information, see the `MockGuild` docstring.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec_set=APIClient, **kwargs)
+
+
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
bot_instance.http_session = None
@@ -352,6 +377,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
def __init__(self, **kwargs) -> None:
super().__init__(spec_set=bot_instance, **kwargs)
+ self.api_client = MockAPIClient()
# self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
# and should therefore be awaited. (The documentation calls it a coroutine as well, which
@@ -515,6 +541,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
self.users = AsyncIteratorMock(kwargs.get('users', []))
+ self.__str__.return_value = str(self.emoji)
webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock())
diff --git a/tox.ini b/tox.ini
index d14819d57..b8293a3b6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,7 +3,7 @@ max-line-length=120
docstring-convention=all
import-order-style=pycharm
application_import_names=bot,tests
-exclude=.cache,.venv,constants.py
+exclude=.cache,.venv,.git,constants.py
ignore=
B311,W503,E226,S311,T000
# Missing Docstrings
@@ -15,5 +15,5 @@ ignore=
# Docstring Content
D400,D401,D402,D404,D405,D406,D407,D408,D409,D410,D411,D412,D413,D414,D416,D417
# Type Annotations
- TYP002,TYP003,TYP101,TYP102,TYP204,TYP206
-per-file-ignores=tests/*:D,TYP
+ ANN002,ANN003,ANN101,ANN102,ANN204,ANN206
+per-file-ignores=tests/*:D,ANN