aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Mark <[email protected]>2020-03-03 08:52:45 -0800
committerGravatar GitHub <[email protected]>2020-03-03 08:52:45 -0800
commitbcf252f0fb5b2188810b228b9c9c0777f45d0c8c (patch)
tree4d8fdbdfd380724f6e699bde445b7b5d144ce3c1
parentFix typo in comment (diff)
parentAdding helpers to the Filtering whitelist (diff)
Merge branch 'master' into bug/backend/b748/resolver-in-coro
-rw-r--r--Dockerfile9
-rw-r--r--Pipfile19
-rw-r--r--Pipfile.lock301
-rw-r--r--azure-pipelines.yml4
-rw-r--r--bot/__init__.py30
-rw-r--r--bot/__main__.py11
-rw-r--r--bot/cogs/antimalware.py10
-rw-r--r--bot/cogs/bot.py5
-rw-r--r--bot/cogs/clean.py2
-rw-r--r--bot/cogs/config_verifier.py40
-rw-r--r--bot/cogs/defcon.py14
-rw-r--r--bot/cogs/error_handler.py258
-rw-r--r--bot/cogs/eval.py4
-rw-r--r--bot/cogs/extensions.py2
-rw-r--r--bot/cogs/free.py2
-rw-r--r--bot/cogs/help.py2
-rw-r--r--bot/cogs/information.py6
-rw-r--r--bot/cogs/jams.py6
-rw-r--r--bot/cogs/logging.py2
-rw-r--r--bot/cogs/moderation/infractions.py2
-rw-r--r--bot/cogs/moderation/management.py8
-rw-r--r--bot/cogs/moderation/modlog.py22
-rw-r--r--bot/cogs/moderation/scheduler.py11
-rw-r--r--bot/cogs/moderation/superstarify.py2
-rw-r--r--bot/cogs/reddit.py9
-rw-r--r--bot/cogs/reminders.py16
-rw-r--r--bot/cogs/snekbox.py135
-rw-r--r--bot/cogs/sync/syncers.py30
-rw-r--r--bot/cogs/tags.py5
-rw-r--r--bot/cogs/utils.py2
-rw-r--r--bot/cogs/verification.py11
-rw-r--r--bot/constants.py57
-rw-r--r--bot/converters.py26
-rw-r--r--bot/utils/__init__.py12
-rw-r--r--bot/utils/scheduling.py83
-rw-r--r--config-default.yml244
-rw-r--r--tests/base.py12
-rw-r--r--tests/bot/cogs/sync/test_base.py49
-rw-r--r--tests/bot/cogs/sync/test_cog.py49
-rw-r--r--tests/bot/cogs/sync/test_roles.py12
-rw-r--r--tests/bot/cogs/sync/test_users.py13
-rw-r--r--tests/bot/cogs/test_duck_pond.py35
-rw-r--r--tests/bot/cogs/test_information.py44
-rw-r--r--tests/bot/cogs/test_snekbox.py354
-rw-r--r--tests/bot/cogs/test_token_remover.py4
-rw-r--r--tests/bot/rules/__init__.py6
-rw-r--r--tests/bot/rules/test_attachments.py4
-rw-r--r--tests/bot/rules/test_burst.py4
-rw-r--r--tests/bot/rules/test_burst_shared.py4
-rw-r--r--tests/bot/rules/test_chars.py4
-rw-r--r--tests/bot/rules/test_discord_emojis.py4
-rw-r--r--tests/bot/rules/test_duplicates.py4
-rw-r--r--tests/bot/rules/test_links.py4
-rw-r--r--tests/bot/rules/test_mentions.py4
-rw-r--r--tests/bot/rules/test_newlines.py5
-rw-r--r--tests/bot/rules/test_role_mentions.py4
-rw-r--r--tests/bot/test_api.py4
-rw-r--r--tests/bot/test_converters.py2
-rw-r--r--tests/bot/test_utils.py15
-rw-r--r--tests/bot/utils/test_time.py5
-rw-r--r--tests/helpers.py215
-rw-r--r--tests/test_base.py20
-rw-r--r--tests/test_helpers.py71
-rw-r--r--tests/utils/test_time.py62
64 files changed, 1367 insertions, 1048 deletions
diff --git a/Dockerfile b/Dockerfile
index 271c25050..06a538b2a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.7-slim
+FROM python:3.8-slim
# Set pip to have cleaner logs and no saved cache
ENV PIP_NO_CACHE_DIR=false \
@@ -9,12 +9,15 @@ ENV PIP_NO_CACHE_DIR=false \
# Install pipenv
RUN pip install -U pipenv
-# Copy project files into working directory
+# Create the working directory
WORKDIR /bot
-COPY . .
# Install project dependencies
+COPY Pipfile* ./
RUN pipenv install --system --deploy
+# Copy the source code in last to optimize rebuilding the image
+COPY . .
+
ENTRYPOINT ["python3"]
CMD ["-m", "bot"]
diff --git a/Pipfile b/Pipfile
index 400e64c18..64760f9dd 100644
--- a/Pipfile
+++ b/Pipfile
@@ -4,7 +4,7 @@ verify_ssl = true
name = "pypi"
[packages]
-discord-py = "~=1.3.1"
+discord-py = "~=1.3.2"
aiodns = "~=2.0"
aiohttp = "~=3.5"
sphinx = "~=2.2"
@@ -16,27 +16,28 @@ aio-pika = "~=6.1"
python-dateutil = "~=2.8"
deepdiff = "~=4.0"
requests = "~=2.22"
-more_itertools = "~=7.2"
-urllib3 = ">=1.24.2,<1.25"
+more_itertools = "~=8.2"
sentry-sdk = "~=0.14"
+coloredlogs = "~=14.0"
+colorama = {version = "~=0.4.3", sys_platform = "== 'win32'"}
[dev-packages]
-coverage = "~=4.5"
+coverage = "~=5.0"
flake8 = "~=3.7"
flake8-annotations = "~=2.0"
-flake8-bugbear = "~=19.8"
+flake8-bugbear = "~=20.1"
flake8-docstrings = "~=1.4"
flake8-import-order = "~=0.18"
flake8-string-format = "~=0.2"
-flake8-tidy-imports = "~=2.0"
+flake8-tidy-imports = "~=4.0"
flake8-todo = "~=0.7"
-pre-commit = "~=1.18"
+pre-commit = "~=2.1"
safety = "~=1.8"
-unittest-xml-reporting = "~=2.5"
+unittest-xml-reporting = "~=3.0"
dodgy = "~=0.1"
[requires]
-python_version = "3.7"
+python_version = "3.8"
[scripts]
start = "python -m bot"
diff --git a/Pipfile.lock b/Pipfile.lock
index fa29bf995..9953aab40 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,11 +1,11 @@
{
"_meta": {
"hash": {
- "sha256": "c7706a61eb96c06d073898018ea2dbcf5bd3b15d007496e2d60120a65647f31e"
+ "sha256": "fae6dcdb6a5ebf27e8ea5044f4ca2ab854774d17affb5fd64ac85f8d0ae71187"
},
"pipfile-spec": 6,
"requires": {
- "python_version": "3.7"
+ "python_version": "3.8"
},
"sources": [
{
@@ -140,6 +140,23 @@
],
"version": "==3.0.4"
},
+ "colorama": {
+ "hashes": [
+ "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff",
+ "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1"
+ ],
+ "index": "pypi",
+ "markers": "sys_platform == 'win32'",
+ "version": "==0.4.3"
+ },
+ "coloredlogs": {
+ "hashes": [
+ "sha256:346f58aad6afd48444c2468618623638dadab76e4e70d5e10822676f2d32226a",
+ "sha256:a1fab193d2053aa6c0a97608c4342d031f1f93a3d1218432c59322441d31a505"
+ ],
+ "index": "pypi",
+ "version": "==14.0"
+ },
"deepdiff": {
"hashes": [
"sha256:b3fa588d1eac7fa318ec1fb4f2004568e04cb120a1989feda8e5e7164bcbf07a",
@@ -150,10 +167,10 @@
},
"discord-py": {
"hashes": [
- "sha256:8bfe5628d31771744000f19135c386c74ac337479d7282c26cc1627b9d31f360"
+ "sha256:7424be26b07b37ecad4404d9383d685995a0e0b3df3f9c645bdd3a4d977b83b4"
],
"index": "pypi",
- "version": "==1.3.1"
+ "version": "==1.3.2"
},
"docutils": {
"hashes": [
@@ -170,6 +187,13 @@
"index": "pypi",
"version": "==0.18.0"
},
+ "humanfriendly": {
+ "hashes": [
+ "sha256:cbe04ecf964ccb951a578f396091f258448ca4b4b4c6d4b6194f48ef458fe991",
+ "sha256:e8e2e4524409e55d5c5cbbb4c555a0c0a9599d5e8f74d0ce1ac504ba51ad1cd2"
+ ],
+ "version": "==7.2"
+ },
"idna": {
"hashes": [
"sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb",
@@ -271,33 +295,33 @@
},
"more-itertools": {
"hashes": [
- "sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832",
- "sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4"
+ "sha256:5dd8bcf33e5f9513ffa06d5ad33d78f31e1931ac9a18f33d37e77a180d393a7c",
+ "sha256:b1ddb932186d8a6ac451e1d95844b382f55e12686d51ca0c68b6f61f2ab7a507"
],
"index": "pypi",
- "version": "==7.2.0"
+ "version": "==8.2.0"
},
"multidict": {
"hashes": [
- "sha256:13f3ebdb5693944f52faa7b2065b751cb7e578b8dd0a5bb8e4ab05ad0188b85e",
- "sha256:26502cefa86d79b86752e96639352c7247846515c864d7c2eb85d036752b643c",
- "sha256:4fba5204d32d5c52439f88437d33ad14b5f228e25072a192453f658bddfe45a7",
- "sha256:527124ef435f39a37b279653ad0238ff606b58328ca7989a6df372fd75d7fe26",
- "sha256:5414f388ffd78c57e77bd253cf829373721f450613de53dc85a08e34d806e8eb",
- "sha256:5eee66f882ab35674944dfa0d28b57fa51e160b4dce0ce19e47f495fdae70703",
- "sha256:63810343ea07f5cd86ba66ab66706243a6f5af075eea50c01e39b4ad6bc3c57a",
- "sha256:6bd10adf9f0d6a98ccc792ab6f83d18674775986ba9bacd376b643fe35633357",
- "sha256:83c6ddf0add57c6b8a7de0bc7e2d656be3eefeff7c922af9a9aae7e49f225625",
- "sha256:93166e0f5379cf6cd29746989f8a594fa7204dcae2e9335ddba39c870a287e1c",
- "sha256:9a7b115ee0b9b92d10ebc246811d8f55d0c57e82dbb6a26b23c9a9a6ad40ce0c",
- "sha256:a38baa3046cce174a07a59952c9f876ae8875ef3559709639c17fdf21f7b30dd",
- "sha256:a6d219f49821f4b2c85c6d426346a5d84dab6daa6f85ca3da6c00ed05b54022d",
- "sha256:a8ed33e8f9b67e3b592c56567135bb42e7e0e97417a4b6a771e60898dfd5182b",
- "sha256:d7d428488c67b09b26928950a395e41cc72bb9c3d5abfe9f0521940ee4f796d4",
- "sha256:dcfed56aa085b89d644af17442cdc2debaa73388feba4b8026446d168ca8dad7",
- "sha256:f29b885e4903bd57a7789f09fe9d60b6475a6c1a4c0eca874d8558f00f9d4b51"
- ],
- "version": "==4.7.4"
+ "sha256:317f96bc0950d249e96d8d29ab556d01dd38888fbe68324f46fd834b430169f1",
+ "sha256:42f56542166040b4474c0c608ed051732033cd821126493cf25b6c276df7dd35",
+ "sha256:4b7df040fb5fe826d689204f9b544af469593fb3ff3a069a6ad3409f742f5928",
+ "sha256:544fae9261232a97102e27a926019100a9db75bec7b37feedd74b3aa82f29969",
+ "sha256:620b37c3fea181dab09267cd5a84b0f23fa043beb8bc50d8474dd9694de1fa6e",
+ "sha256:6e6fef114741c4d7ca46da8449038ec8b1e880bbe68674c01ceeb1ac8a648e78",
+ "sha256:7774e9f6c9af3f12f296131453f7b81dabb7ebdb948483362f5afcaac8a826f1",
+ "sha256:85cb26c38c96f76b7ff38b86c9d560dea10cf3459bb5f4caf72fc1bb932c7136",
+ "sha256:a326f4240123a2ac66bb163eeba99578e9d63a8654a59f4688a79198f9aa10f8",
+ "sha256:ae402f43604e3b2bc41e8ea8b8526c7fa7139ed76b0d64fc48e28125925275b2",
+ "sha256:aee283c49601fa4c13adc64c09c978838a7e812f85377ae130a24d7198c0331e",
+ "sha256:b51249fdd2923739cd3efc95a3d6c363b67bbf779208e9f37fd5e68540d1a4d4",
+ "sha256:bb519becc46275c594410c6c28a8a0adc66fe24fef154a9addea54c1adb006f5",
+ "sha256:c2c37185fb0af79d5c117b8d2764f4321eeb12ba8c141a95d0aa8c2c1d0a11dd",
+ "sha256:dc561313279f9d05a3d0ffa89cd15ae477528ea37aa9795c4654588a3287a9ab",
+ "sha256:e439c9a10a95cb32abd708bb8be83b2134fa93790a4fb0535ca36db3dda94d20",
+ "sha256:fc3b4adc2ee8474cb3cd2a155305d5f8eda0a9c91320f83e55748e1fcb68f8e3"
+ ],
+ "version": "==4.7.5"
},
"ordered-set": {
"hashes": [
@@ -355,7 +379,8 @@
},
"pycparser": {
"hashes": [
- "sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3"
+ "sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3",
+ "sha256:fd64020e8a5e0369de455adf9f22795a90fdb74e6bb999e9a13fd26b54f533ef"
],
"version": "==2.19"
},
@@ -373,6 +398,13 @@
],
"version": "==2.4.6"
},
+ "pyreadline": {
+ "hashes": [
+ "sha256:4530592fc2e85b25b1a9f79664433da09237c1a270e4d78ea5aa3a2c7229e2d1"
+ ],
+ "markers": "sys_platform == 'win32'",
+ "version": "==2.1"
+ },
"python-dateutil": {
"hashes": [
"sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c",
@@ -415,11 +447,11 @@
},
"sentry-sdk": {
"hashes": [
- "sha256:b06dd27391fd11fb32f84fe054e6a64736c469514a718a99fb5ce1dff95d6b28",
- "sha256:e023da07cfbead3868e1e2ba994160517885a32dfd994fc455b118e37989479b"
+ "sha256:480eee754e60bcae983787a9a13bc8f155a111aef199afaa4f289d6a76aa622a",
+ "sha256:a920387dc3ee252a66679d0afecd34479fb6fc52c2bc20763793ed69e5b0dcc0"
],
"index": "pypi",
- "version": "==0.14.1"
+ "version": "==0.14.2"
},
"six": {
"hashes": [
@@ -437,39 +469,39 @@
},
"soupsieve": {
"hashes": [
- "sha256:bdb0d917b03a1369ce964056fc195cfdff8819c40de04695a80bc813c3cfa1f5",
- "sha256:e2c1c5dee4a1c36bcb790e0fabd5492d874b8ebd4617622c4f6a731701060dda"
+ "sha256:e914534802d7ffd233242b785229d5ba0766a7f487385e3f714446a07bf540ae",
+ "sha256:fcd71e08c0aee99aca1b73f45478549ee7e7fc006d51b37bec9e9def7dc22b69"
],
- "version": "==1.9.5"
+ "version": "==2.0"
},
"sphinx": {
"hashes": [
- "sha256:525527074f2e0c2585f68f73c99b4dc257c34bbe308b27f5f8c7a6e20642742f",
- "sha256:543d39db5f82d83a5c1aa0c10c88f2b6cff2da3e711aa849b2c627b4b403bbd9"
+ "sha256:776ff8333181138fae52df65be733127539623bb46cc692e7fa0fcfc80d7aa88",
+ "sha256:ca762da97c3b5107cbf0ab9e11d3ec7ab8d3c31377266fd613b962ed971df709"
],
"index": "pypi",
- "version": "==2.4.2"
+ "version": "==2.4.3"
},
"sphinxcontrib-applehelp": {
"hashes": [
- "sha256:edaa0ab2b2bc74403149cb0209d6775c96de797dfd5b5e2a71981309efab3897",
- "sha256:fb8dee85af95e5c30c91f10e7eb3c8967308518e0f7488a2828ef7bc191d0d5d"
+ "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a",
+ "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58"
],
- "version": "==1.0.1"
+ "version": "==1.0.2"
},
"sphinxcontrib-devhelp": {
"hashes": [
- "sha256:6c64b077937330a9128a4da74586e8c2130262f014689b4b89e2d08ee7294a34",
- "sha256:9512ecb00a2b0821a146736b39f7aeb90759834b07e81e8cc23a9c70bacb9981"
+ "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e",
+ "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4"
],
- "version": "==1.0.1"
+ "version": "==1.0.2"
},
"sphinxcontrib-htmlhelp": {
"hashes": [
- "sha256:4670f99f8951bd78cd4ad2ab962f798f5618b17675c35c5ac3b2132a14ea8422",
- "sha256:d4fd39a65a625c9df86d7fa8a2d9f3cd8299a3a4b15db63b50aac9e161d8eff7"
+ "sha256:3c0bc24a2c41e340ac37c85ced6dafc879ab485c095b1d65d2461ac2f7cca86f",
+ "sha256:e8f5bb7e31b2dbb25b9cc435c8ab7a79787ebf7f906155729338f3156d93659b"
],
- "version": "==1.0.2"
+ "version": "==1.0.3"
},
"sphinxcontrib-jsmath": {
"hashes": [
@@ -480,25 +512,24 @@
},
"sphinxcontrib-qthelp": {
"hashes": [
- "sha256:513049b93031beb1f57d4daea74068a4feb77aa5630f856fcff2e50de14e9a20",
- "sha256:79465ce11ae5694ff165becda529a600c754f4bc459778778c7017374d4d406f"
+ "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72",
+ "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6"
],
- "version": "==1.0.2"
+ "version": "==1.0.3"
},
"sphinxcontrib-serializinghtml": {
"hashes": [
- "sha256:c0efb33f8052c04fd7a26c0a07f1678e8512e0faec19f4aa8f2473a8b81d5227",
- "sha256:db6615af393650bf1151a6cd39120c29abaf93cc60db8c48eb2dddbfdc3a9768"
+ "sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc",
+ "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a"
],
- "version": "==1.1.3"
+ "version": "==1.1.4"
},
"urllib3": {
"hashes": [
- "sha256:2393a695cd12afedd0dcb26fe5d50d0cf248e5a66f75dbd89a3d4eb333a61af4",
- "sha256:a637e5fae88995b256e3409dc4d52c2e2e0ba32c42a6365fee8bbd2238de3cfb"
+ "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc",
+ "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc"
],
- "index": "pypi",
- "version": "==1.24.3"
+ "version": "==1.25.8"
},
"websockets": {
"hashes": [
@@ -558,13 +589,6 @@
],
"version": "==1.4.3"
},
- "aspy.yaml": {
- "hashes": [
- "sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc",
- "sha256:e7c742382eff2caed61f87a39d13f99109088e5e93f04d76eb8d4b28aa143f45"
- ],
- "version": "==1.3.0"
- },
"attrs": {
"hashes": [
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
@@ -581,10 +605,10 @@
},
"cfgv": {
"hashes": [
- "sha256:04b093b14ddf9fd4d17c53ebfd55582d27b76ed30050193c14e560770c5360eb",
- "sha256:f22b426ed59cd2ab2b54ff96608d846c33dfb8766a67f0b4a6ce130ce244414f"
+ "sha256:1ccf53320421aeeb915275a196e23b3b8ae87dea8ac6698b1638001d4a486d53",
+ "sha256:c8e8f552ffcc6194f4e18dd4f68d9aef0c0d58ae7e7be8c82bee3c5e9edfa513"
],
- "version": "==3.0.0"
+ "version": "==3.1.0"
},
"chardet": {
"hashes": [
@@ -602,45 +626,45 @@
},
"coverage": {
"hashes": [
- "sha256:08907593569fe59baca0bf152c43f3863201efb6113ecb38ce7e97ce339805a6",
- "sha256:0be0f1ed45fc0c185cfd4ecc19a1d6532d72f86a2bac9de7e24541febad72650",
- "sha256:141f08ed3c4b1847015e2cd62ec06d35e67a3ac185c26f7635f4406b90afa9c5",
- "sha256:19e4df788a0581238e9390c85a7a09af39c7b539b29f25c89209e6c3e371270d",
- "sha256:23cc09ed395b03424d1ae30dcc292615c1372bfba7141eb85e11e50efaa6b351",
- "sha256:245388cda02af78276b479f299bbf3783ef0a6a6273037d7c60dc73b8d8d7755",
- "sha256:331cb5115673a20fb131dadd22f5bcaf7677ef758741312bee4937d71a14b2ef",
- "sha256:386e2e4090f0bc5df274e720105c342263423e77ee8826002dcffe0c9533dbca",
- "sha256:3a794ce50daee01c74a494919d5ebdc23d58873747fa0e288318728533a3e1ca",
- "sha256:60851187677b24c6085248f0a0b9b98d49cba7ecc7ec60ba6b9d2e5574ac1ee9",
- "sha256:63a9a5fc43b58735f65ed63d2cf43508f462dc49857da70b8980ad78d41d52fc",
- "sha256:6b62544bb68106e3f00b21c8930e83e584fdca005d4fffd29bb39fb3ffa03cb5",
- "sha256:6ba744056423ef8d450cf627289166da65903885272055fb4b5e113137cfa14f",
- "sha256:7494b0b0274c5072bddbfd5b4a6c6f18fbbe1ab1d22a41e99cd2d00c8f96ecfe",
- "sha256:826f32b9547c8091679ff292a82aca9c7b9650f9fda3e2ca6bf2ac905b7ce888",
- "sha256:93715dffbcd0678057f947f496484e906bf9509f5c1c38fc9ba3922893cda5f5",
- "sha256:9a334d6c83dfeadae576b4d633a71620d40d1c379129d587faa42ee3e2a85cce",
- "sha256:af7ed8a8aa6957aac47b4268631fa1df984643f07ef00acd374e456364b373f5",
- "sha256:bf0a7aed7f5521c7ca67febd57db473af4762b9622254291fbcbb8cd0ba5e33e",
- "sha256:bf1ef9eb901113a9805287e090452c05547578eaab1b62e4ad456fcc049a9b7e",
- "sha256:c0afd27bc0e307a1ffc04ca5ec010a290e49e3afbe841c5cafc5c5a80ecd81c9",
- "sha256:dd579709a87092c6dbee09d1b7cfa81831040705ffa12a1b248935274aee0437",
- "sha256:df6712284b2e44a065097846488f66840445eb987eb81b3cc6e4149e7b6982e1",
- "sha256:e07d9f1a23e9e93ab5c62902833bf3e4b1f65502927379148b6622686223125c",
- "sha256:e2ede7c1d45e65e209d6093b762e98e8318ddeff95317d07a27a2140b80cfd24",
- "sha256:e4ef9c164eb55123c62411f5936b5c2e521b12356037b6e1c2617cef45523d47",
- "sha256:eca2b7343524e7ba246cab8ff00cab47a2d6d54ada3b02772e908a45675722e2",
- "sha256:eee64c616adeff7db37cc37da4180a3a5b6177f5c46b187894e633f088fb5b28",
- "sha256:ef824cad1f980d27f26166f86856efe11eff9912c4fed97d3804820d43fa550c",
- "sha256:efc89291bd5a08855829a3c522df16d856455297cf35ae827a37edac45f466a7",
- "sha256:fa964bae817babece5aa2e8c1af841bebb6d0b9add8e637548809d040443fee0",
- "sha256:ff37757e068ae606659c28c3bd0d923f9d29a85de79bf25b2b34b148473b5025"
+ "sha256:15cf13a6896048d6d947bf7d222f36e4809ab926894beb748fc9caa14605d9c3",
+ "sha256:1daa3eceed220f9fdb80d5ff950dd95112cd27f70d004c7918ca6dfc6c47054c",
+ "sha256:1e44a022500d944d42f94df76727ba3fc0a5c0b672c358b61067abb88caee7a0",
+ "sha256:25dbf1110d70bab68a74b4b9d74f30e99b177cde3388e07cc7272f2168bd1477",
+ "sha256:3230d1003eec018ad4a472d254991e34241e0bbd513e97a29727c7c2f637bd2a",
+ "sha256:3dbb72eaeea5763676a1a1efd9b427a048c97c39ed92e13336e726117d0b72bf",
+ "sha256:5012d3b8d5a500834783689a5d2292fe06ec75dc86ee1ccdad04b6f5bf231691",
+ "sha256:51bc7710b13a2ae0c726f69756cf7ffd4362f4ac36546e243136187cfcc8aa73",
+ "sha256:527b4f316e6bf7755082a783726da20671a0cc388b786a64417780b90565b987",
+ "sha256:722e4557c8039aad9592c6a4213db75da08c2cd9945320220634f637251c3894",
+ "sha256:76e2057e8ffba5472fd28a3a010431fd9e928885ff480cb278877c6e9943cc2e",
+ "sha256:77afca04240c40450c331fa796b3eab6f1e15c5ecf8bf2b8bee9706cd5452fef",
+ "sha256:7afad9835e7a651d3551eab18cbc0fdb888f0a6136169fbef0662d9cdc9987cf",
+ "sha256:9bea19ac2f08672636350f203db89382121c9c2ade85d945953ef3c8cf9d2a68",
+ "sha256:a8b8ac7876bc3598e43e2603f772d2353d9931709345ad6c1149009fd1bc81b8",
+ "sha256:b0840b45187699affd4c6588286d429cd79a99d509fe3de0f209594669bb0954",
+ "sha256:b26aaf69713e5674efbde4d728fb7124e429c9466aeaf5f4a7e9e699b12c9fe2",
+ "sha256:b63dd43f455ba878e5e9f80ba4f748c0a2156dde6e0e6e690310e24d6e8caf40",
+ "sha256:be18f4ae5a9e46edae3f329de2191747966a34a3d93046dbdf897319923923bc",
+ "sha256:c312e57847db2526bc92b9bfa78266bfbaabac3fdcd751df4d062cd4c23e46dc",
+ "sha256:c60097190fe9dc2b329a0eb03393e2e0829156a589bd732e70794c0dd804258e",
+ "sha256:c62a2143e1313944bf4a5ab34fd3b4be15367a02e9478b0ce800cb510e3bbb9d",
+ "sha256:cc1109f54a14d940b8512ee9f1c3975c181bbb200306c6d8b87d93376538782f",
+ "sha256:cd60f507c125ac0ad83f05803063bed27e50fa903b9c2cfee3f8a6867ca600fc",
+ "sha256:d513cc3db248e566e07a0da99c230aca3556d9b09ed02f420664e2da97eac301",
+ "sha256:d649dc0bcace6fcdb446ae02b98798a856593b19b637c1b9af8edadf2b150bea",
+ "sha256:d7008a6796095a79544f4da1ee49418901961c97ca9e9d44904205ff7d6aa8cb",
+ "sha256:da93027835164b8223e8e5af2cf902a4c80ed93cb0909417234f4a9df3bcd9af",
+ "sha256:e69215621707119c6baf99bda014a45b999d37602cb7043d943c76a59b05bf52",
+ "sha256:ea9525e0fef2de9208250d6c5aeeee0138921057cd67fcef90fbed49c4d62d37",
+ "sha256:fca1669d464f0c9831fd10be2eef6b86f5ebd76c724d1e0706ebdff86bb4adf0"
],
"index": "pypi",
- "version": "==4.5.4"
+ "version": "==5.0.3"
},
"distlib": {
"hashes": [
- "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21"
+ "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21",
+ "sha256:9b183fb98f4870e02d315d5d17baef14be74c339d827346cae544f5597698555"
],
"version": "==0.3.0"
},
@@ -691,11 +715,11 @@
},
"flake8-bugbear": {
"hashes": [
- "sha256:d8c466ea79d5020cb20bf9f11cf349026e09517a42264f313d3f6fddb83e0571",
- "sha256:ded4d282778969b5ab5530ceba7aa1a9f1b86fa7618fc96a19a1d512331640f8"
+ "sha256:a3ddc03ec28ba2296fc6f89444d1c946a6b76460f859795b35b77d4920a51b63",
+ "sha256:bd02e4b009fb153fe6072c31c52aeab5b133d508095befb2ffcf3b41c4823162"
],
"index": "pypi",
- "version": "==19.8.0"
+ "version": "==20.1.4"
},
"flake8-docstrings": {
"hashes": [
@@ -723,11 +747,11 @@
},
"flake8-tidy-imports": {
"hashes": [
- "sha256:1c476aabc6e8db26dc75278464a3a392dba0ea80562777c5f13fd5cdf2646154",
- "sha256:b3f5b96affd0f57cacb6621ed28286ce67edaca807757b51227043ebf7b136a1"
+ "sha256:8aa34384b45137d4cf33f5818b8e7897dc903b1d1e10a503fa7dd193a9a710ba",
+ "sha256:b26461561bcc80e8012e46846630ecf0aaa59314f362a94cb7800dfdb32fa413"
],
"index": "pypi",
- "version": "==2.0.0"
+ "version": "==4.0.0"
},
"flake8-todo": {
"hashes": [
@@ -750,14 +774,6 @@
],
"version": "==2.9"
},
- "importlib-metadata": {
- "hashes": [
- "sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302",
- "sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b"
- ],
- "markers": "python_version < '3.8'",
- "version": "==1.5.0"
- },
"mccabe": {
"hashes": [
"sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42",
@@ -780,11 +796,11 @@
},
"pre-commit": {
"hashes": [
- "sha256:8f48d8637bdae6fa70cc97db9c1dd5aa7c5c8bf71968932a380628c25978b850",
- "sha256:f92a359477f3252452ae2e8d3029de77aec59415c16ae4189bcfba40b757e029"
+ "sha256:09ebe467f43ce24377f8c2f200fe3cd2570d328eb2ce0568c8e96ce19da45fa6",
+ "sha256:f8d555e31e2051892c7f7b3ad9f620bd2c09271d87e9eedb2ad831737d6211eb"
],
"index": "pypi",
- "version": "==1.21.0"
+ "version": "==2.1.1"
},
"pycodestyle": {
"hashes": [
@@ -868,62 +884,27 @@
],
"version": "==0.10.0"
},
- "typed-ast": {
- "hashes": [
- "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355",
- "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919",
- "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa",
- "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652",
- "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75",
- "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01",
- "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d",
- "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1",
- "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907",
- "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c",
- "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3",
- "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b",
- "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614",
- "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb",
- "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b",
- "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41",
- "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6",
- "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34",
- "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe",
- "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4",
- "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7"
- ],
- "markers": "python_version < '3.8'",
- "version": "==1.4.1"
- },
"unittest-xml-reporting": {
"hashes": [
- "sha256:358bbdaf24a26d904cc1c26ef3078bca7fc81541e0a54c8961693cc96a6f35e0",
- "sha256:9d28ddf6524cf0ff9293f61bd12e792de298f8561a5c945acea63fb437789e0e"
+ "sha256:74eaf7739a7957a74f52b8187c5616f61157372189bef0a32ba5c30bbc00e58a",
+ "sha256:e09b8ae70cce9904cdd331f53bf929150962869a5324ab7ff3dd6c8b87e01f7d"
],
"index": "pypi",
- "version": "==2.5.2"
+ "version": "==3.0.2"
},
"urllib3": {
"hashes": [
- "sha256:2393a695cd12afedd0dcb26fe5d50d0cf248e5a66f75dbd89a3d4eb333a61af4",
- "sha256:a637e5fae88995b256e3409dc4d52c2e2e0ba32c42a6365fee8bbd2238de3cfb"
+ "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc",
+ "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc"
],
- "index": "pypi",
- "version": "==1.24.3"
+ "version": "==1.25.8"
},
"virtualenv": {
"hashes": [
- "sha256:08f3623597ce73b85d6854fb26608a6f39ee9d055c81178dc6583803797f8994",
- "sha256:de2cbdd5926c48d7b84e0300dea9e8f276f61d186e8e49223d71d91250fbaebd"
- ],
- "version": "==20.0.4"
- },
- "zipp": {
- "hashes": [
- "sha256:12248a63bbdf7548f89cb4c7cda4681e537031eda29c02ea29674bc6854460c2",
- "sha256:7c0f8e91abc0dc07a5068f315c52cb30c66bfbc581e5b50704c8a2f6ebae794a"
+ "sha256:30ea90b21dabd11da5f509710ad3be2ae47d40ccbc717dfdd2efe4367c10f598",
+ "sha256:4a36a96d785428278edd389d9c36d763c5755844beb7509279194647b1ef47f1"
],
- "version": "==3.0.0"
+ "version": "==20.0.7"
}
}
}
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 874364a6f..35dea089a 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -9,7 +9,7 @@ jobs:
- job: test
displayName: 'Lint & Test'
pool:
- vmImage: ubuntu-16.04
+ vmImage: ubuntu-18.04
variables:
PIP_CACHE_DIR: ".cache/pip"
@@ -18,7 +18,7 @@ jobs:
- task: UsePythonVersion@0
displayName: 'Set Python version'
inputs:
- versionSpec: '3.7.x'
+ versionSpec: '3.8.x'
addToPath: true
- script: pip install pipenv
diff --git a/bot/__init__.py b/bot/__init__.py
index f7a410706..c9dbc3f40 100644
--- a/bot/__init__.py
+++ b/bot/__init__.py
@@ -1,9 +1,11 @@
import logging
import os
import sys
-from logging import Logger, StreamHandler, handlers
+from logging import Logger, handlers
from pathlib import Path
+import coloredlogs
+
TRACE_LEVEL = logging.TRACE = 5
logging.addLevelName(TRACE_LEVEL, "TRACE")
@@ -25,10 +27,9 @@ Logger.trace = monkeypatch_trace
DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local")
-log_format = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s")
-
-stream_handler = StreamHandler(stream=sys.stdout)
-stream_handler.setFormatter(log_format)
+log_level = TRACE_LEVEL if DEBUG_MODE else logging.INFO
+format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
+log_format = logging.Formatter(format_string)
log_file = Path("logs", "bot.log")
log_file.parent.mkdir(exist_ok=True)
@@ -36,10 +37,25 @@ file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCo
file_handler.setFormatter(log_format)
root_log = logging.getLogger()
-root_log.setLevel(TRACE_LEVEL if DEBUG_MODE else logging.INFO)
-root_log.addHandler(stream_handler)
+root_log.setLevel(log_level)
root_log.addHandler(file_handler)
+if "COLOREDLOGS_LEVEL_STYLES" not in os.environ:
+ coloredlogs.DEFAULT_LEVEL_STYLES = {
+ **coloredlogs.DEFAULT_LEVEL_STYLES,
+ "trace": {"color": 246},
+ "critical": {"background": "red"},
+ "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"]
+ }
+
+if "COLOREDLOGS_LOG_FORMAT" not in os.environ:
+ coloredlogs.DEFAULT_LOG_FORMAT = format_string
+
+if "COLOREDLOGS_LOG_LEVEL" not in os.environ:
+ coloredlogs.DEFAULT_LOG_LEVEL = log_level
+
+coloredlogs.install(logger=root_log, stream=sys.stdout)
+
logging.getLogger("discord").setLevel(logging.WARNING)
logging.getLogger("websockets").setLevel(logging.WARNING)
logging.getLogger(__name__)
diff --git a/bot/__main__.py b/bot/__main__.py
index 490163739..3df477a6d 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -7,10 +7,10 @@ from sentry_sdk.integrations.logging import LoggingIntegration
from bot import patches
from bot.bot import Bot
-from bot.constants import Bot as BotConfig, DEBUG_MODE
+from bot.constants import Bot as BotConfig
sentry_logging = LoggingIntegration(
- level=logging.TRACE,
+ level=logging.DEBUG,
event_level=logging.WARNING
)
@@ -31,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler")
bot.load_extension("bot.cogs.filtering")
bot.load_extension("bot.cogs.logging")
bot.load_extension("bot.cogs.security")
+bot.load_extension("bot.cogs.config_verifier")
# Commands, etc
bot.load_extension("bot.cogs.antimalware")
@@ -40,10 +41,8 @@ bot.load_extension("bot.cogs.clean")
bot.load_extension("bot.cogs.extensions")
bot.load_extension("bot.cogs.help")
-# Only load this in production
-if not DEBUG_MODE:
- bot.load_extension("bot.cogs.doc")
- bot.load_extension("bot.cogs.verification")
+bot.load_extension("bot.cogs.doc")
+bot.load_extension("bot.cogs.verification")
# Feature cogs
bot.load_extension("bot.cogs.alias")
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
index 28e3e5d96..9e9e81364 100644
--- a/bot/cogs/antimalware.py
+++ b/bot/cogs/antimalware.py
@@ -4,7 +4,7 @@ from discord import Embed, Message, NotFound
from discord.ext.commands import Cog
from bot.bot import Bot
-from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs
log = logging.getLogger(__name__)
@@ -18,7 +18,13 @@ class AntiMalware(Cog):
@Cog.listener()
async def on_message(self, message: Message) -> None:
"""Identify messages with prohibited attachments."""
- if not message.attachments:
+ # Return when message don't have attachment and don't moderate DMs
+ if not message.attachments or not message.guild:
+ return
+
+ # Check if user is staff, if is, return
+ # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance
+ if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles):
return
embed = Embed()
diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py
index 73b1e8f41..f17135877 100644
--- a/bot/cogs/bot.py
+++ b/bot/cogs/bot.py
@@ -34,13 +34,12 @@ class BotCog(Cog, name="Bot"):
Channels.help_5: 0,
Channels.help_6: 0,
Channels.help_7: 0,
- Channels.python: 0,
+ Channels.python_discussion: 0,
}
# These channels will also work, but will not be subject to cooldown
self.channel_whitelist = (
- Channels.bot,
- Channels.devtest,
+ Channels.bot_commands,
)
# Stores improperly formatted Python codeblock message ids and the corresponding bot message
diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py
index 2104efe57..5cdf0b048 100644
--- a/bot/cogs/clean.py
+++ b/bot/cogs/clean.py
@@ -173,7 +173,7 @@ class Clean(Cog):
colour=Colour(Colours.soft_red),
title="Bulk message delete",
text=message,
- channel_id=Channels.modlog,
+ channel_id=Channels.mod_log,
)
@group(invoke_without_command=True, name="clean", aliases=["purge"])
diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py
new file mode 100644
index 000000000..d72c6c22e
--- /dev/null
+++ b/bot/cogs/config_verifier.py
@@ -0,0 +1,40 @@
+import logging
+
+from discord.ext.commands import Cog
+
+from bot import constants
+from bot.bot import Bot
+
+
+log = logging.getLogger(__name__)
+
+
+class ConfigVerifier(Cog):
+ """Verify config on startup."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self.channel_verify_task = self.bot.loop.create_task(self.verify_channels())
+
+ async def verify_channels(self) -> None:
+ """
+ Verify channels.
+
+ If any channels in config aren't present in server, log them in a warning.
+ """
+ await self.bot.wait_until_guild_available()
+ server = self.bot.get_guild(constants.Guild.id)
+
+ server_channel_ids = {channel.id for channel in server.channels}
+ invalid_channels = [
+ channel_name for channel_name, channel_id in constants.Channels
+ if channel_id not in server_channel_ids
+ ]
+
+ if invalid_channels:
+ log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.")
+
+
+def setup(bot: Bot) -> None:
+ """Load the ConfigVerifier cog."""
+ bot.add_cog(ConfigVerifier(bot))
diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py
index 20961e0a2..cc0f79fe8 100644
--- a/bot/cogs/defcon.py
+++ b/bot/cogs/defcon.py
@@ -68,8 +68,8 @@ class Defcon(Cog):
except Exception: # Yikes!
log.exception("Unable to get DEFCON settings!")
- await self.bot.get_channel(Channels.devlog).send(
- f"<@&{Roles.admin}> **WARNING**: Unable to get DEFCON settings!"
+ await self.bot.get_channel(Channels.dev_log).send(
+ f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!"
)
else:
@@ -118,7 +118,7 @@ class Defcon(Cog):
)
@group(name='defcon', aliases=('dc',), invoke_without_command=True)
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def defcon_group(self, ctx: Context) -> None:
"""Check the DEFCON status or run a subcommand."""
await ctx.invoke(self.bot.get_command("help"), "defcon")
@@ -146,7 +146,7 @@ class Defcon(Cog):
await self.send_defcon_log(action, ctx.author, error)
@defcon_group.command(name='enable', aliases=('on', 'e'))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def enable_command(self, ctx: Context) -> None:
"""
Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!
@@ -159,7 +159,7 @@ class Defcon(Cog):
await self.update_channel_topic()
@defcon_group.command(name='disable', aliases=('off', 'd'))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def disable_command(self, ctx: Context) -> None:
"""Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!"""
self.enabled = False
@@ -167,7 +167,7 @@ class Defcon(Cog):
await self.update_channel_topic()
@defcon_group.command(name='status', aliases=('s',))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def status_command(self, ctx: Context) -> None:
"""Check the current status of DEFCON mode."""
embed = Embed(
@@ -179,7 +179,7 @@ class Defcon(Cog):
await ctx.send(embed=embed)
@defcon_group.command(name='days')
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def days_command(self, ctx: Context, days: int) -> None:
"""Set how old an account must be to join the server, in days, with DEFCON mode enabled."""
self.days = timedelta(days=days)
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py
index 0abb7e521..261769efc 100644
--- a/bot/cogs/error_handler.py
+++ b/bot/cogs/error_handler.py
@@ -1,25 +1,14 @@
import contextlib
import logging
+import typing as t
-from discord.ext.commands import (
- BadArgument,
- BotMissingPermissions,
- CheckFailure,
- CommandError,
- CommandInvokeError,
- CommandNotFound,
- CommandOnCooldown,
- DisabledCommand,
- MissingPermissions,
- NoPrivateMessage,
- UserInputError,
-)
-from discord.ext.commands import Cog, Context
+from discord.ext.commands import Cog, Command, Context, errors
from sentry_sdk import push_scope
from bot.api import ResponseCodeError
from bot.bot import Bot
from bot.constants import Channels
+from bot.converters import TagNameConverter
from bot.decorators import InChannelCheckFailure
log = logging.getLogger(__name__)
@@ -32,118 +21,185 @@ class ErrorHandler(Cog):
self.bot = bot
@Cog.listener()
- async def on_command_error(self, ctx: Context, e: CommandError) -> None:
+ async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None:
"""
Provide generic command error handling.
- Error handling is deferred to any local error handler, if present.
-
- Error handling emits a single error response, prioritized as follows:
- 1. If the name fails to match a command but matches a tag, the tag is invoked
- 2. Send a BadArgument error message to the invoking context & invoke the command's help
- 3. Send a UserInputError error message to the invoking context & invoke the command's help
- 4. Send a NoPrivateMessage error message to the invoking context
- 5. Send a BotMissingPermissions error message to the invoking context
- 6. Log a MissingPermissions error, no message is sent
- 7. Send a InChannelCheckFailure error message to the invoking context
- 8. Log CheckFailure, CommandOnCooldown, and DisabledCommand errors, no message is sent
- 9. For CommandInvokeErrors, response is based on the type of error:
- * 404: Error message is sent to the invoking context
- * 400: Log the resopnse JSON, no message is sent
- * 500 <= status <= 600: Error message is sent to the invoking context
- 10. Otherwise, handling is deferred to `handle_unexpected_error`
+ Error handling is deferred to any local error handler, if present. This is done by
+ checking for the presence of a `handled` attribute on the error.
+
+ Error handling emits a single error message in the invoking context `ctx` and a log message,
+ prioritised as follows:
+
+ 1. If the name fails to match a command but matches a tag, the tag is invoked
+ * If CommandNotFound is raised when invoking the tag (determined by the presence of the
+ `invoked_from_error_handler` attribute), this error is treated as being unexpected
+ and therefore sends an error message
+ * Commands in the verification channel are ignored
+ 2. UserInputError: see `handle_user_input_error`
+ 3. CheckFailure: see `handle_check_failure`
+ 4. CommandOnCooldown: send an error message in the invoking context
+ 5. ResponseCodeError: see `handle_api_error`
+ 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error`
"""
command = ctx.command
- parent = None
+ if hasattr(e, "handled"):
+ log.trace(f"Command {command} had its error already handled locally; ignoring.")
+ return
+
+ # Try to look for a tag with the command's name if the command isn't found.
+ if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
+ if ctx.channel.id != Channels.verification:
+ await self.try_get_tag(ctx)
+ return # Exit early to avoid logging.
+ elif isinstance(e, errors.UserInputError):
+ await self.handle_user_input_error(ctx, e)
+ elif isinstance(e, errors.CheckFailure):
+ await self.handle_check_failure(ctx, e)
+ elif isinstance(e, errors.CommandOnCooldown):
+ await ctx.send(e)
+ elif isinstance(e, errors.CommandInvokeError):
+ if isinstance(e.original, ResponseCodeError):
+ await self.handle_api_error(ctx, e.original)
+ else:
+ await self.handle_unexpected_error(ctx, e.original)
+ return # Exit early to avoid logging.
+ elif not isinstance(e, errors.DisabledCommand):
+ # ConversionError, MaxConcurrencyReached, ExtensionError
+ await self.handle_unexpected_error(ctx, e)
+ return # Exit early to avoid logging.
+
+ log.debug(
+ f"Command {command} invoked by {ctx.message.author} with error "
+ f"{e.__class__.__name__}: {e}"
+ )
+
+ async def get_help_command(self, command: t.Optional[Command]) -> t.Tuple:
+ """Return the help command invocation args to display help for `command`."""
+ parent = None
if command is not None:
parent = command.parent
# Retrieve the help command for the invoked command.
if parent and command:
- help_command = (self.bot.get_command("help"), parent.name, command.name)
+ return self.bot.get_command("help"), parent.name, command.name
elif command:
- help_command = (self.bot.get_command("help"), command.name)
+ return self.bot.get_command("help"), command.name
else:
- help_command = (self.bot.get_command("help"),)
+ return self.bot.get_command("help")
- if hasattr(e, "handled"):
- log.trace(f"Command {command} had its error already handled locally; ignoring.")
+ async def try_get_tag(self, ctx: Context) -> None:
+ """
+ Attempt to display a tag by interpreting the command name as a tag name.
+
+ The invocation of tags get respects its checks. Any CommandErrors raised will be handled
+ by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to
+ the context to prevent infinite recursion in the case of a CommandNotFound exception.
+ """
+ tags_get_command = self.bot.get_command("tags get")
+ ctx.invoked_from_error_handler = True
+
+ log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
+ try:
+ if not await tags_get_command.can_run(ctx):
+ log.debug(log_msg)
+ return
+ except errors.CommandError as tag_error:
+ log.debug(log_msg)
+ await self.on_command_error(ctx, tag_error)
return
- # Try to look for a tag with the command's name if the command isn't found.
- if isinstance(e, CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
- if not ctx.channel.id == Channels.verification:
- tags_get_command = self.bot.get_command("tags get")
- ctx.invoked_from_error_handler = True
-
- log_msg = "Cancelling attempt to fall back to a tag due to failed checks."
- try:
- if not await tags_get_command.can_run(ctx):
- log.debug(log_msg)
- return
- except CommandError as tag_error:
- log.debug(log_msg)
- await self.on_command_error(ctx, tag_error)
- return
-
- # Return to not raise the exception
- with contextlib.suppress(ResponseCodeError):
- await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with)
- return
- elif isinstance(e, BadArgument):
+ try:
+ tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with)
+ except errors.BadArgument:
+ log.debug(
+ f"{ctx.author} tried to use an invalid command "
+ f"and the fallback tag failed validation in TagNameConverter."
+ )
+ else:
+ with contextlib.suppress(ResponseCodeError):
+ await ctx.invoke(tags_get_command, tag_name=tag_name)
+ # Return to not raise the exception
+ return
+
+ async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None:
+ """
+ Send an error message in `ctx` for UserInputError, sometimes invoking the help command too.
+
+ * MissingRequiredArgument: send an error message with arg name and the help command
+ * TooManyArguments: send an error message and the help command
+ * BadArgument: send an error message and the help command
+ * BadUnionArgument: send an error message including the error produced by the last converter
+ * ArgumentParsingError: send an error message
+ * Other: send an error message and the help command
+ """
+ # TODO: use ctx.send_help() once PR #519 is merged.
+ help_command = await self.get_help_command(ctx.command)
+
+ if isinstance(e, errors.MissingRequiredArgument):
+ await ctx.send(f"Missing required argument `{e.param.name}`.")
+ await ctx.invoke(*help_command)
+ elif isinstance(e, errors.TooManyArguments):
+ await ctx.send(f"Too many arguments provided.")
+ await ctx.invoke(*help_command)
+ elif isinstance(e, errors.BadArgument):
await ctx.send(f"Bad argument: {e}\n")
await ctx.invoke(*help_command)
- elif isinstance(e, UserInputError):
+ elif isinstance(e, errors.BadUnionArgument):
+ await ctx.send(f"Bad argument: {e}\n```{e.errors[-1]}```")
+ elif isinstance(e, errors.ArgumentParsingError):
+ await ctx.send(f"Argument parsing error: {e}")
+ else:
await ctx.send("Something about your input seems off. Check the arguments:")
await ctx.invoke(*help_command)
- log.debug(
- f"Command {command} invoked by {ctx.message.author} with error "
- f"{e.__class__.__name__}: {e}"
- )
- elif isinstance(e, NoPrivateMessage):
- await ctx.send("Sorry, this command can't be used in a private message!")
- elif isinstance(e, BotMissingPermissions):
- await ctx.send(f"Sorry, it looks like I don't have the permissions I need to do that.")
- log.warning(
- f"The bot is missing permissions to execute command {command}: {e.missing_perms}"
- )
- elif isinstance(e, MissingPermissions):
- log.debug(
- f"{ctx.message.author} is missing permissions to invoke command {command}: "
- f"{e.missing_perms}"
+
+ @staticmethod
+ async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None:
+ """
+ Send an error message in `ctx` for certain types of CheckFailure.
+
+ The following types are handled:
+
+ * BotMissingPermissions
+ * BotMissingRole
+ * BotMissingAnyRole
+ * NoPrivateMessage
+ * InChannelCheckFailure
+ """
+ bot_missing_errors = (
+ errors.BotMissingPermissions,
+ errors.BotMissingRole,
+ errors.BotMissingAnyRole
+ )
+
+ if isinstance(e, bot_missing_errors):
+ await ctx.send(
+ f"Sorry, it looks like I don't have the permissions or roles I need to do that."
)
- elif isinstance(e, InChannelCheckFailure):
+ elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)):
await ctx.send(e)
- elif isinstance(e, (CheckFailure, CommandOnCooldown, DisabledCommand)):
- log.debug(
- f"Command {command} invoked by {ctx.message.author} with error "
- f"{e.__class__.__name__}: {e}"
- )
- elif isinstance(e, CommandInvokeError):
- if isinstance(e.original, ResponseCodeError):
- status = e.original.response.status
-
- if status == 404:
- await ctx.send("There does not seem to be anything matching your query.")
- elif status == 400:
- content = await e.original.response.json()
- log.debug(f"API responded with 400 for command {command}: %r.", content)
- await ctx.send("According to the API, your request is malformed.")
- elif 500 <= status < 600:
- await ctx.send("Sorry, there seems to be an internal issue with the API.")
- log.warning(f"API responded with {status} for command {command}")
- else:
- await ctx.send(f"Got an unexpected status code from the API (`{status}`).")
- log.warning(f"Unexpected API response for command {command}: {status}")
- else:
- await self.handle_unexpected_error(ctx, e.original)
+
+ @staticmethod
+ async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None:
+ """Send an error message in `ctx` for ResponseCodeError and log it."""
+ if e.status == 404:
+ await ctx.send("There does not seem to be anything matching your query.")
+ log.debug(f"API responded with 404 for command {ctx.command}")
+ elif e.status == 400:
+ content = await e.response.json()
+ log.debug(f"API responded with 400 for command {ctx.command}: %r.", content)
+ await ctx.send("According to the API, your request is malformed.")
+ elif 500 <= e.status < 600:
+ await ctx.send("Sorry, there seems to be an internal issue with the API.")
+ log.warning(f"API responded with {e.status} for command {ctx.command}")
else:
- await self.handle_unexpected_error(ctx, e)
+ await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).")
+ log.warning(f"Unexpected API response for command {ctx.command}: {e.status}")
@staticmethod
- async def handle_unexpected_error(ctx: Context, e: CommandError) -> None:
- """Generic handler for errors without an explicit handler."""
+ async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None:
+ """Send a generic error message in `ctx` and log the exception as an error with exc_info."""
await ctx.send(
f"Sorry, an unexpected error occurred. Please let us know!\n\n"
f"```{e.__class__.__name__}: {e}```"
diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py
index 9c729f28a..52136fc8d 100644
--- a/bot/cogs/eval.py
+++ b/bot/cogs/eval.py
@@ -174,14 +174,14 @@ async def func(): # (None,) -> Any
await ctx.send(f"```py\n{out}```", embed=embed)
@group(name='internal', aliases=('int',))
- @with_role(Roles.owner, Roles.admin)
+ @with_role(Roles.owners, Roles.admins)
async def internal_group(self, ctx: Context) -> None:
"""Internal commands. Top secret!"""
if not ctx.invoked_subcommand:
await ctx.invoke(self.bot.get_command("help"), "internal")
@internal_group.command(name='eval', aliases=('e',))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def eval(self, ctx: Context, *, code: str) -> None:
"""Run eval in a REPL-like format."""
code = code.strip("`")
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
index f16e79fb7..b312e1a1d 100644
--- a/bot/cogs/extensions.py
+++ b/bot/cogs/extensions.py
@@ -221,7 +221,7 @@ class Extensions(commands.Cog):
# This cannot be static (must have a __func__ attribute).
def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators and core developers to invoke the commands in this cog."""
- return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer)
+ return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers)
# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
diff --git a/bot/cogs/free.py b/bot/cogs/free.py
index 49cab6172..02c02d067 100644
--- a/bot/cogs/free.py
+++ b/bot/cogs/free.py
@@ -22,7 +22,7 @@ class Free(Cog):
PYTHON_HELP_ID = Categories.python_help
@command(name="free", aliases=('f',))
- @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)
+ @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES)
async def free(self, ctx: Context, user: Member = None, seek: int = 2) -> None:
"""
Lists free help channels by likeliness of availability.
diff --git a/bot/cogs/help.py b/bot/cogs/help.py
index fd5bbc3ca..744722220 100644
--- a/bot/cogs/help.py
+++ b/bot/cogs/help.py
@@ -507,7 +507,7 @@ class Help(DiscordCog):
"""Custom Embed Pagination Help feature."""
@commands.command('help')
- @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)
+ @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES)
async def new_help(self, ctx: Context, *commands) -> None:
"""Shows Command Help."""
try:
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index 13c8aabaa..49beca15b 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -152,8 +152,8 @@ class Information(Cog):
# Non-staff may only do this in #bot-commands
if not with_role_check(ctx, *constants.STAFF_ROLES):
- if not ctx.channel.id == constants.Channels.bot:
- raise InChannelCheckFailure(constants.Channels.bot)
+ if not ctx.channel.id == constants.Channels.bot_commands:
+ raise InChannelCheckFailure(constants.Channels.bot_commands)
embed = await self.create_user_embed(ctx, user)
@@ -332,7 +332,7 @@ class Information(Cog):
@cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES)
@group(invoke_without_command=True)
- @in_channel(constants.Channels.bot, bypass_roles=constants.STAFF_ROLES)
+ @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES)
async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None:
"""Shows information about the raw API response."""
# I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling
diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py
index 985f28ce5..1d062b0c2 100644
--- a/bot/cogs/jams.py
+++ b/bot/cogs/jams.py
@@ -18,7 +18,7 @@ class CodeJams(commands.Cog):
self.bot = bot
@commands.command()
- @with_role(Roles.admin)
+ @with_role(Roles.admins)
async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None:
"""
Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team.
@@ -95,10 +95,10 @@ class CodeJams(commands.Cog):
)
# Assign team leader role
- await members[0].add_roles(ctx.guild.get_role(Roles.team_leader))
+ await members[0].add_roles(ctx.guild.get_role(Roles.team_leaders))
# Assign rest of roles
- jammer_role = ctx.guild.get_role(Roles.jammer)
+ jammer_role = ctx.guild.get_role(Roles.jammers)
for member in members:
await member.add_roles(jammer_role)
diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py
index dbd76672f..94fa2b139 100644
--- a/bot/cogs/logging.py
+++ b/bot/cogs/logging.py
@@ -34,7 +34,7 @@ class Logging(Cog):
)
if not DEBUG_MODE:
- await self.bot.get_channel(Channels.devlog).send(embed=embed)
+ await self.bot.get_channel(Channels.dev_log).send(embed=embed)
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py
index f4e296df9..9ea17b2b3 100644
--- a/bot/cogs/moderation/infractions.py
+++ b/bot/cogs/moderation/infractions.py
@@ -313,6 +313,6 @@ class Infractions(InfractionScheduler, commands.Cog):
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""Send a notification to the invoking context on a Union failure."""
if isinstance(error, commands.BadUnionArgument):
- if discord.User in error.converters:
+ if discord.User in error.converters or discord.Member in error.converters:
await ctx.send(str(error.errors[0]))
error.handled = True
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index f2964cd78..35448f682 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -1,4 +1,3 @@
-import asyncio
import logging
import textwrap
import typing as t
@@ -129,12 +128,13 @@ class ModManagement(commands.Cog):
# Re-schedule infraction if the expiration has been updated
if 'expires_at' in request_data:
- self.infractions_cog.cancel_task(new_infraction['id'])
+ # A scheduled task should only exist if the old infraction wasn't permanent
+ if old_infraction['expires_at']:
+ self.infractions_cog.cancel_task(new_infraction['id'])
# If the infraction was not marked as permanent, schedule a new expiration task
if request_data['expires_at']:
- loop = asyncio.get_event_loop()
- self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction)
+ self.infractions_cog.schedule_task(new_infraction['id'], new_infraction)
log_text += f"""
Previous expiry: {old_infraction['expires_at'] or "Permanent"}
diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py
index e8ae0dbe6..59ae6b587 100644
--- a/bot/cogs/moderation/modlog.py
+++ b/bot/cogs/moderation/modlog.py
@@ -87,7 +87,7 @@ class ModLog(Cog, name="ModLog"):
title: t.Optional[str],
text: str,
thumbnail: t.Optional[t.Union[str, discord.Asset]] = None,
- channel_id: int = Channels.modlog,
+ channel_id: int = Channels.mod_log,
ping_everyone: bool = False,
files: t.Optional[t.List[discord.File]] = None,
content: t.Optional[str] = None,
@@ -377,7 +377,7 @@ class ModLog(Cog, name="ModLog"):
Icons.user_ban, Colours.soft_red,
"User banned", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"):
Icons.sign_in, Colours.soft_green,
"User joined", message,
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -416,7 +416,7 @@ class ModLog(Cog, name="ModLog"):
Icons.sign_out, Colours.soft_red,
"User left", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -433,7 +433,7 @@ class ModLog(Cog, name="ModLog"):
Icons.user_unban, Colour.blurple(),
"User unbanned", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.modlog
+ channel_id=Channels.mod_log
)
@Cog.listener()
@@ -529,7 +529,7 @@ class ModLog(Cog, name="ModLog"):
Icons.user_update, Colour.blurple(),
"Member updated", message,
thumbnail=after.avatar_url_as(static_format="png"),
- channel_id=Channels.userlog
+ channel_id=Channels.user_log
)
@Cog.listener()
@@ -538,7 +538,7 @@ class ModLog(Cog, name="ModLog"):
channel = message.channel
author = message.author
- if message.guild.id != GuildConstant.id or channel.id in GuildConstant.ignored:
+ if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist:
return
self._cached_deletes.append(message.id)
@@ -591,7 +591,7 @@ class ModLog(Cog, name="ModLog"):
@Cog.listener()
async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None:
"""Log raw message delete event to message change log."""
- if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.ignored:
+ if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist:
return
await asyncio.sleep(1) # Wait here in case the normal event was fired
@@ -635,7 +635,7 @@ class ModLog(Cog, name="ModLog"):
if (
not msg_before.guild
or msg_before.guild.id != GuildConstant.id
- or msg_before.channel.id in GuildConstant.ignored
+ or msg_before.channel.id in GuildConstant.modlog_blacklist
or msg_before.author.bot
):
return
@@ -717,7 +717,7 @@ class ModLog(Cog, name="ModLog"):
if (
not message.guild
or message.guild.id != GuildConstant.id
- or message.channel.id in GuildConstant.ignored
+ or message.channel.id in GuildConstant.modlog_blacklist
or message.author.bot
):
return
@@ -769,7 +769,7 @@ class ModLog(Cog, name="ModLog"):
"""Log member voice state changes to the voice log channel."""
if (
member.guild.id != GuildConstant.id
- or (before.channel and before.channel.id in GuildConstant.ignored)
+ or (before.channel and before.channel.id in GuildConstant.modlog_blacklist)
):
return
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py
index 3c5185468..f0b6b2c48 100644
--- a/bot/cogs/moderation/scheduler.py
+++ b/bot/cogs/moderation/scheduler.py
@@ -1,3 +1,4 @@
+import asyncio
import logging
import textwrap
import typing as t
@@ -48,7 +49,7 @@ class InfractionScheduler(Scheduler):
)
for infraction in infractions:
if infraction["expires_at"] is not None and infraction["type"] in supported_infractions:
- self.schedule_task(self.bot.loop, infraction["id"], infraction)
+ self.schedule_task(infraction["id"], infraction)
async def reapply_infraction(
self,
@@ -150,7 +151,7 @@ class InfractionScheduler(Scheduler):
await action_coro
if expiry:
# Schedule the expiration of the infraction.
- self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
+ self.schedule_task(infraction["id"], infraction)
except discord.HTTPException as e:
# Accordingly display that applying the infraction failed.
confirm_msg = f":x: failed to apply"
@@ -307,7 +308,7 @@ class InfractionScheduler(Scheduler):
Infractions of unsupported types will raise a ValueError.
"""
guild = self.bot.get_guild(constants.Guild.id)
- mod_role = guild.get_role(constants.Roles.moderator)
+ mod_role = guild.get_role(constants.Roles.moderators)
user_id = infraction["user"]
actor = infraction["actor"]
type_ = infraction["type"]
@@ -427,4 +428,6 @@ class InfractionScheduler(Scheduler):
expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None)
await time.wait_until(expiry)
- await self.deactivate_infraction(infraction)
+ # Because deactivate_infraction() explicitly cancels this scheduled task, it is shielded
+ # to avoid prematurely cancelling itself.
+ await asyncio.shield(self.deactivate_infraction(infraction))
diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py
index c41874a95..893cb7f13 100644
--- a/bot/cogs/moderation/superstarify.py
+++ b/bot/cogs/moderation/superstarify.py
@@ -146,7 +146,7 @@ class Superstarify(InfractionScheduler, Cog):
log.debug(f"Changing nickname of {member} to {forced_nick}.")
self.mod_log.ignore(constants.Event.member_update, member.id)
await member.edit(nick=forced_nick, reason=reason)
- self.schedule_task(ctx.bot.loop, id_, infraction)
+ self.schedule_task(id_, infraction)
# Send a DM to the user to notify them of their new infraction.
await utils.notify_infraction(
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 4f6584aba..5a7fa100f 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -43,8 +43,8 @@ class Reddit(Cog):
def cog_unload(self) -> None:
"""Stop the loop task and revoke the access token when the cog is unloaded."""
self.auto_poster_loop.cancel()
- if self.access_token.expires_at < datetime.utcnow():
- self.revoke_access_token()
+ if self.access_token and self.access_token.expires_at > datetime.utcnow():
+ asyncio.create_task(self.revoke_access_token())
async def init_reddit_ready(self) -> None:
"""Sets the reddit webhook when the cog is loaded."""
@@ -83,7 +83,7 @@ class Reddit(Cog):
expires_at=datetime.utcnow() + timedelta(seconds=expiration)
)
- log.debug(f"New token acquired; expires on {self.access_token.expires_at}")
+ log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}")
return
else:
log.debug(
@@ -290,4 +290,7 @@ class Reddit(Cog):
def setup(bot: Bot) -> None:
"""Load the Reddit cog."""
+ if not RedditConfig.secret or not RedditConfig.client_id:
+ log.error("Credentials not provided, cog not loaded.")
+ return
bot.add_cog(Reddit(bot))
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 041791056..24c279357 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -43,7 +43,6 @@ class Reminders(Scheduler, Cog):
)
now = datetime.utcnow()
- loop = asyncio.get_event_loop()
for reminder in response:
is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False)
@@ -57,7 +56,7 @@ class Reminders(Scheduler, Cog):
late = relativedelta(now, remind_at)
await self.send_reminder(reminder, late)
else:
- self.schedule_task(loop, reminder["id"], reminder)
+ self.schedule_task(reminder["id"], reminder)
def ensure_valid_reminder(
self,
@@ -112,9 +111,6 @@ class Reminders(Scheduler, Cog):
log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).")
await self._delete_reminder(reminder_id)
- # Now we can begone with it from our schedule list.
- self.cancel_task(reminder_id)
-
async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None:
"""Delete a reminder from the database, given its ID, and cancel the running task."""
await self.bot.api_client.delete('bot/reminders/' + str(reminder_id))
@@ -125,10 +121,11 @@ class Reminders(Scheduler, Cog):
async def _reschedule_reminder(self, reminder: dict) -> None:
"""Reschedule a reminder object."""
- loop = asyncio.get_event_loop()
-
+ log.trace(f"Cancelling old task #{reminder['id']}")
self.cancel_task(reminder["id"])
- self.schedule_task(loop, reminder["id"], reminder)
+
+ log.trace(f"Scheduling new task #{reminder['id']}")
+ self.schedule_task(reminder["id"], reminder)
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
@@ -226,8 +223,7 @@ class Reminders(Scheduler, Cog):
delivery_dt=expiration,
)
- loop = asyncio.get_event_loop()
- self.schedule_task(loop, reminder["id"], reminder)
+ self.schedule_task(reminder["id"], reminder)
@remind_group.command(name="list")
async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]:
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index da33e27b2..cff7c5786 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -1,10 +1,14 @@
+import asyncio
+import contextlib
import datetime
import logging
import re
import textwrap
+from functools import partial
from signal import Signals
from typing import Optional, Tuple
+from discord import HTTPException, Message, NotFound, Reaction, User
from discord.ext.commands import Cog, Context, command, guild_only
from bot.bot import Bot
@@ -34,7 +38,11 @@ RAW_CODE_REGEX = re.compile(
)
MAX_PASTE_LEN = 1000
-EVAL_ROLES = (Roles.helpers, Roles.moderator, Roles.admin, Roles.owner, Roles.rockstars, Roles.partners)
+EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
+
+SIGKILL = 9
+
+REEVAL_EMOJI = '\U0001f501' # :repeat:
class Snekbox(Cog):
@@ -101,7 +109,7 @@ class Snekbox(Cog):
if returncode is None:
msg = "Your eval job has failed"
error = stdout.strip()
- elif returncode == 128 + Signals.SIGKILL:
+ elif returncode == 128 + SIGKILL:
msg = "Your eval job timed out or ran out of memory"
elif returncode == 255:
msg = "Your eval job has failed"
@@ -135,7 +143,7 @@ class Snekbox(Cog):
"""
log.trace("Formatting output...")
- output = output.strip(" \n")
+ output = output.rstrip("\n")
original_output = output # To be uploaded to a pasting service if needed
paste_link = None
@@ -152,8 +160,8 @@ class Snekbox(Cog):
lines = output.count("\n")
if lines > 0:
- output = output.split("\n")[:10] # Only first 10 cause the rest is truncated anyway
- output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1))
+ output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)]
+ output = output[:11] # Limiting to only 11 lines
output = "\n".join(output)
if lines > 10:
@@ -169,21 +177,84 @@ class Snekbox(Cog):
if truncated:
paste_link = await self.upload_output(original_output)
- output = output.strip()
- if not output:
- output = "[No output]"
+ output = output or "[No output]"
return output, paste_link
+ async def send_eval(self, ctx: Context, code: str) -> Message:
+ """
+ Evaluate code, format it, and send the output to the corresponding channel.
+
+ Return the bot response.
+ """
+ async with ctx.typing():
+ results = await self.post_eval(code)
+ msg, error = self.get_results_message(results)
+
+ if error:
+ output, paste_link = error, None
+ else:
+ output, paste_link = await self.format_output(results["stdout"])
+
+ icon = self.get_status_emoji(results)
+ msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```"
+ if paste_link:
+ msg = f"{msg}\nFull output: {paste_link}"
+
+ response = await ctx.send(msg)
+ self.bot.loop.create_task(
+ wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)
+ )
+
+ log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
+ return response
+
+ async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]:
+ """
+ Check if the eval session should continue.
+
+ Return the new code to evaluate or None if the eval session should be terminated.
+ """
+ _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx)
+ _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx)
+
+ with contextlib.suppress(NotFound):
+ try:
+ _, new_message = await self.bot.wait_for(
+ 'message_edit',
+ check=_predicate_eval_message_edit,
+ timeout=10
+ )
+ await ctx.message.add_reaction(REEVAL_EMOJI)
+ await self.bot.wait_for(
+ 'reaction_add',
+ check=_predicate_emoji_reaction,
+ timeout=10
+ )
+
+ code = new_message.content.split(' ', maxsplit=1)[1]
+ await ctx.message.clear_reactions()
+ with contextlib.suppress(HTTPException):
+ await response.delete()
+
+ except asyncio.TimeoutError:
+ await ctx.message.clear_reactions()
+ return None
+
+ return code
+
@command(name="eval", aliases=("e",))
@guild_only()
- @in_channel(Channels.bot, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
+ @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES)
async def eval_command(self, ctx: Context, *, code: str = None) -> None:
"""
Run Python code and get the results.
This command supports multiple lines of code, including code wrapped inside a formatted code
- block. We've done our best to make this safe, but do let us know if you manage to find an
+ block. Code can be re-evaluated by editing the original message within 10 seconds and
+ clicking the reaction that subsequently appears.
+
+ We've done our best to make this sandboxed, but do let us know if you manage to find an
issue with it!
"""
if ctx.author.id in self.jobs:
@@ -199,32 +270,28 @@ class Snekbox(Cog):
log.info(f"Received code from {ctx.author} for evaluation:\n{code}")
- self.jobs[ctx.author.id] = datetime.datetime.now()
- code = self.prepare_input(code)
+ while True:
+ self.jobs[ctx.author.id] = datetime.datetime.now()
+ code = self.prepare_input(code)
+ try:
+ response = await self.send_eval(ctx, code)
+ finally:
+ del self.jobs[ctx.author.id]
+
+ code = await self.continue_eval(ctx, response)
+ if not code:
+ break
+ log.info(f"Re-evaluating message {ctx.message.id}")
+
+
+def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool:
+ """Return True if the edited message is the context message and the content was indeed modified."""
+ return new_msg.id == ctx.message.id and old_msg.content != new_msg.content
- try:
- async with ctx.typing():
- results = await self.post_eval(code)
- msg, error = self.get_results_message(results)
-
- if error:
- output, paste_link = error, None
- else:
- output, paste_link = await self.format_output(results["stdout"])
-
- icon = self.get_status_emoji(results)
- msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```"
- if paste_link:
- msg = f"{msg}\nFull output: {paste_link}"
-
- response = await ctx.send(msg)
- self.bot.loop.create_task(
- wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)
- )
- log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
- finally:
- del self.jobs[ctx.author.id]
+def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool:
+ """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message."""
+ return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py
index 6715ad6fb..c7ce54d65 100644
--- a/bot/cogs/sync/syncers.py
+++ b/bot/cogs/sync/syncers.py
@@ -23,7 +23,7 @@ _Diff = namedtuple('Diff', ('created', 'updated', 'deleted'))
class Syncer(abc.ABC):
"""Base class for synchronising the database with objects in the Discord cache."""
- _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developer}> "
+ _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> "
_REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark)
def __init__(self, bot: Bot) -> None:
@@ -54,12 +54,12 @@ class Syncer(abc.ABC):
# Send to core developers if it's an automatic sync.
if not message:
log.trace("Message not provided for confirmation; creating a new one in dev-core.")
- channel = self.bot.get_channel(constants.Channels.devcore)
+ channel = self.bot.get_channel(constants.Channels.dev_core)
if not channel:
log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.")
try:
- channel = await self.bot.fetch_channel(constants.Channels.devcore)
+ channel = await self.bot.fetch_channel(constants.Channels.dev_core)
except HTTPException:
log.exception(
f"Failed to fetch channel for sending sync confirmation prompt; "
@@ -93,7 +93,7 @@ class Syncer(abc.ABC):
`author` of the prompt.
"""
# For automatic syncs, check for the core dev role instead of an exact author
- has_role = any(constants.Roles.core_developer == role.id for role in user.roles)
+ has_role = any(constants.Roles.core_developers == role.id for role in user.roles)
return (
reaction.message.id == message.id
and not user.bot
@@ -125,17 +125,17 @@ class Syncer(abc.ABC):
except TimeoutError:
# reaction will remain none thus sync will be aborted in the finally block below.
log.debug(f"The {self.name} syncer confirmation prompt timed out.")
- finally:
- if str(reaction) == constants.Emojis.check_mark:
- log.trace(f"The {self.name} syncer was confirmed.")
- await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.')
- return True
- else:
- log.warning(f"The {self.name} syncer was aborted or timed out!")
- await message.edit(
- content=f':warning: {mention}{self.name} sync aborted or timed out!'
- )
- return False
+
+ if str(reaction) == constants.Emojis.check_mark:
+ log.trace(f"The {self.name} syncer was confirmed.")
+ await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.')
+ return True
+ else:
+ log.warning(f"The {self.name} syncer was aborted or timed out!")
+ await message.edit(
+ content=f':warning: {mention}{self.name} sync aborted or timed out!'
+ )
+ return False
@abc.abstractmethod
async def _get_diff(self, guild: Guild) -> _Diff:
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py
index b6360dfae..5da9a4148 100644
--- a/bot/cogs/tags.py
+++ b/bot/cogs/tags.py
@@ -15,8 +15,7 @@ from bot.pagination import LinePaginator
log = logging.getLogger(__name__)
TEST_CHANNELS = (
- Channels.devtest,
- Channels.bot,
+ Channels.bot_commands,
Channels.helpers
)
@@ -221,7 +220,7 @@ class Tags(Cog):
))
@tags_group.command(name='delete', aliases=('remove', 'rm', 'd'))
- @with_role(Roles.admin, Roles.owner)
+ @with_role(Roles.admins, Roles.owners)
async def delete_command(self, ctx: Context, *, tag_name: TagNameConverter) -> None:
"""Remove a tag from the database."""
await self.bot.api_client.delete(f'bot/tags/{tag_name}')
diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py
index da278011a..94b9d6b5a 100644
--- a/bot/cogs/utils.py
+++ b/bot/cogs/utils.py
@@ -89,7 +89,7 @@ class Utils(Cog):
await ctx.message.channel.send(embed=pep_embed)
@command()
- @in_channel(Channels.bot, bypass_roles=STAFF_ROLES)
+ @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES)
async def charinfo(self, ctx: Context, *, characters: str) -> None:
"""Shows you information on up to 25 unicode characters."""
match = re.match(r"<(a?):(\w+):(\d+)>", characters)
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index e3c396863..57b50c34f 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -30,15 +30,16 @@ your information removed here as well.
Feel free to review them at any point!
Additionally, if you'd like to receive notifications for the announcements we post in <#{Channels.announcements}> \
-from time to time, you can send `!subscribe` to <#{Channels.bot}> at any time to assign yourself the \
+from time to time, you can send `!subscribe` to <#{Channels.bot_commands}> at any time to assign yourself the \
**Announcements** role. We'll mention this role every time we make an announcement.
-If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to <#{Channels.bot}>.
+If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \
+<#{Channels.bot_commands}>.
"""
PERIODIC_PING = (
f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`."
- f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel."
+ f" If you encounter any problems during the verification process, ping the <@&{Roles.admins}> role in this channel."
)
BOT_MESSAGE_DELETE_DELAY = 10
@@ -136,7 +137,7 @@ class Verification(Cog):
await ctx.message.delete()
@command(name='subscribe')
- @in_channel(Channels.bot)
+ @in_channel(Channels.bot_commands)
async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Subscribe to announcement notifications by assigning yourself the role."""
has_role = False
@@ -160,7 +161,7 @@ class Verification(Cog):
)
@command(name='unsubscribe')
- @in_channel(Channels.bot)
+ @in_channel(Channels.bot_commands)
async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args
"""Unsubscribe from announcement notifications by removing the role from yourself."""
has_role = False
diff --git a/bot/constants.py b/bot/constants.py
index 9bc331dc4..14f8dc094 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -186,6 +186,11 @@ class YAMLGetter(type):
def __getitem__(cls, name):
return cls.__getattr__(name)
+ def __iter__(cls):
+ """Return generator of key: value pairs of current constants class' config values."""
+ for name in cls.__annotations__:
+ yield name, getattr(cls, name)
+
# Dataclasses
class Bot(metaclass=YAMLGetter):
@@ -358,18 +363,16 @@ class Channels(metaclass=YAMLGetter):
section = "guild"
subsection = "channels"
- admins: int
admin_spam: int
+ admins: int
announcements: int
attachment_log: int
big_brother_logs: int
- bot: int
- checkpoint_test: int
+ bot_commands: int
defcon: int
- devcontrib: int
- devcore: int
- devlog: int
- devtest: int
+ dev_contrib: int
+ dev_core: int
+ dev_log: int
esoteric: int
help_0: int
help_1: int
@@ -382,19 +385,19 @@ class Channels(metaclass=YAMLGetter):
helpers: int
message_log: int
meta: int
+ mod_alerts: int
+ mod_log: int
mod_spam: int
mods: int
- mod_alerts: int
- modlog: int
off_topic_0: int
off_topic_1: int
off_topic_2: int
organisation: int
- python: int
+ python_discussion: int
reddit: int
talent_pool: int
- userlog: int
- user_event_a: int
+ user_event_announcements: int
+ user_log: int
verification: int
voice_log: int
@@ -414,19 +417,18 @@ class Roles(metaclass=YAMLGetter):
section = "guild"
subsection = "roles"
- admin: int
+ admins: int
announcements: int
- champion: int
- contributor: int
- core_developer: int
+ contributors: int
+ core_developers: int
helpers: int
- jammer: int
- moderator: int
+ jammers: int
+ moderators: int
muted: int
- owner: int
+ owners: int
partners: int
- rockstars: int
- team_leader: int
+ python_community: int
+ team_leaders: int
verified: int # This is the Developers role on PyDis, here named verified for readability reasons.
@@ -434,9 +436,12 @@ class Guild(metaclass=YAMLGetter):
section = "guild"
id: int
- ignored: List[int]
- staff_channels: List[int]
+ moderation_channels: List[int]
+ moderation_roles: List[int]
+ modlog_blacklist: List[int]
reminder_whitelist: List[int]
+ staff_channels: List[int]
+ staff_roles: List[int]
class Keys(metaclass=YAMLGetter):
section = "keys"
@@ -582,14 +587,14 @@ BOT_DIR = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir))
# Default role combinations
-MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner
-STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner
+MODERATION_ROLES = Guild.moderation_roles
+STAFF_ROLES = Guild.staff_roles
# Roles combinations
STAFF_CHANNELS = Guild.staff_channels
# Default Channel combinations
-MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam
+MODERATION_CHANNELS = Guild.moderation_channels
# Bot replies
diff --git a/bot/converters.py b/bot/converters.py
index cca57a02d..1945e1da3 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -141,40 +141,24 @@ class TagNameConverter(Converter):
@staticmethod
async def convert(ctx: Context, tag_name: str) -> str:
"""Lowercase & strip whitespace from proposed tag_name & ensure it's valid."""
- def is_number(value: str) -> bool:
- """Check to see if the input string is numeric."""
- try:
- float(value)
- except ValueError:
- return False
- return True
-
tag_name = tag_name.lower().strip()
# The tag name has at least one invalid character.
if ascii(tag_name)[1:-1] != tag_name:
- log.warning(f"{ctx.author} tried to put an invalid character in a tag name. "
- "Rejecting the request.")
raise BadArgument("Don't be ridiculous, you can't use that character!")
# The tag name is either empty, or consists of nothing but whitespace.
elif not tag_name:
- log.warning(f"{ctx.author} tried to create a tag with a name consisting only of whitespace. "
- "Rejecting the request.")
raise BadArgument("Tag names should not be empty, or filled with whitespace.")
- # The tag name is a number of some kind, we don't allow that.
- elif is_number(tag_name):
- log.warning(f"{ctx.author} tried to create a tag with a digit as its name. "
- "Rejecting the request.")
- raise BadArgument("Tag names can't be numbers.")
-
# The tag name is longer than 127 characters.
elif len(tag_name) > 127:
- log.warning(f"{ctx.author} tried to request a tag name with over 127 characters. "
- "Rejecting the request.")
raise BadArgument("Are you insane? That's way too long!")
+ # The tag name is ascii but does not contain any letters.
+ elif not any(character.isalpha() for character in tag_name):
+ raise BadArgument("Tag names must contain at least one letter.")
+
return tag_name
@@ -192,8 +176,6 @@ class TagContentConverter(Converter):
# The tag contents should not be empty, or filled with whitespace.
if not tag_content:
- log.warning(f"{ctx.author} tried to create a tag containing only whitespace. "
- "Rejecting the request.")
raise BadArgument("Tag contents should not be empty, or filled with whitespace.")
return tag_content
diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py
index 8184be824..3e4b15ce4 100644
--- a/bot/utils/__init__.py
+++ b/bot/utils/__init__.py
@@ -1,5 +1,5 @@
from abc import ABCMeta
-from typing import Any, Generator, Hashable, Iterable
+from typing import Any, Hashable
from discord.ext.commands import CogMeta
@@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict):
for k in list(self.keys()):
v = super(CaseInsensitiveDict, self).pop(k)
self.__setitem__(k, v)
-
-
-def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]:
- """
- Generator that allows you to iterate over any indexable collection in `size`-length chunks.
-
- Found: https://stackoverflow.com/a/312464/4022104
- """
- for i in range(0, len(iterable), size):
- yield iterable[i:i + size]
diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py
index ee6c0a8e6..5760ec2d4 100644
--- a/bot/utils/scheduling.py
+++ b/bot/utils/scheduling.py
@@ -1,8 +1,9 @@
import asyncio
import contextlib
import logging
+import typing as t
from abc import abstractmethod
-from typing import Coroutine, Dict, Union
+from functools import partial
from bot.utils import CogABCMeta
@@ -13,12 +14,13 @@ class Scheduler(metaclass=CogABCMeta):
"""Task scheduler."""
def __init__(self):
+ # Keep track of the child cog's name so the logs are clear.
+ self.cog_name = self.__class__.__name__
- self.cog_name = self.__class__.__name__ # keep track of the child cog's name so the logs are clear.
- self.scheduled_tasks: Dict[str, asyncio.Task] = {}
+ self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {}
@abstractmethod
- async def _scheduled_task(self, task_object: dict) -> None:
+ async def _scheduled_task(self, task_object: t.Any) -> None:
"""
A coroutine which handles the scheduling.
@@ -29,46 +31,73 @@ class Scheduler(metaclass=CogABCMeta):
then make a site API request to delete the reminder from the database.
"""
- def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict) -> None:
+ def schedule_task(self, task_id: t.Hashable, task_data: t.Any) -> None:
"""
Schedules a task.
- `task_data` is passed to `Scheduler._scheduled_expiration`
+ `task_data` is passed to the `Scheduler._scheduled_task()` coroutine.
"""
- if task_id in self.scheduled_tasks:
+ log.trace(f"{self.cog_name}: scheduling task #{task_id}...")
+
+ if task_id in self._scheduled_tasks:
log.debug(
f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled."
)
return
- task: asyncio.Task = create_task(loop, self._scheduled_task(task_data))
+ task = asyncio.create_task(self._scheduled_task(task_data))
+ task.add_done_callback(partial(self._task_done_callback, task_id))
- self.scheduled_tasks[task_id] = task
- log.debug(f"{self.cog_name}: scheduled task #{task_id}.")
+ self._scheduled_tasks[task_id] = task
+ log.debug(f"{self.cog_name}: scheduled task #{task_id} {id(task)}.")
- def cancel_task(self, task_id: str) -> None:
- """Un-schedules a task."""
- task = self.scheduled_tasks.get(task_id)
+ def cancel_task(self, task_id: t.Hashable) -> None:
+ """Unschedule the task identified by `task_id`."""
+ log.trace(f"{self.cog_name}: cancelling task #{task_id}...")
+ task = self._scheduled_tasks.get(task_id)
- if task is None:
- log.warning(f"{self.cog_name}: Failed to unschedule {task_id} (no task found).")
+ if not task:
+ log.warning(f"{self.cog_name}: failed to unschedule {task_id} (no task found).")
return
task.cancel()
- log.debug(f"{self.cog_name}: unscheduled task #{task_id}.")
- del self.scheduled_tasks[task_id]
+ del self._scheduled_tasks[task_id]
+
+ log.debug(f"{self.cog_name}: unscheduled task #{task_id} {id(task)}.")
+ def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None:
+ """
+ Delete the task and raise its exception if one exists.
-def create_task(loop: asyncio.AbstractEventLoop, coro_or_future: Union[Coroutine, asyncio.Future]) -> asyncio.Task:
- """Creates an asyncio.Task object from a coroutine or future object."""
- task: asyncio.Task = asyncio.ensure_future(coro_or_future, loop=loop)
+ If `done_task` and the task associated with `task_id` are different, then the latter
+ will not be deleted. In this case, a new task was likely rescheduled with the same ID.
+ """
+ log.trace(f"{self.cog_name}: performing done callback for task #{task_id} {id(done_task)}.")
- # Silently ignore exceptions in a callback (handles the CancelledError nonsense)
- task.add_done_callback(_silent_exception)
- return task
+ scheduled_task = self._scheduled_tasks.get(task_id)
+ if scheduled_task and done_task is scheduled_task:
+ # A task for the ID exists and its the same as the done task.
+ # Since this is the done callback, the task is already done so no need to cancel it.
+ log.trace(f"{self.cog_name}: deleting task #{task_id} {id(done_task)}.")
+ del self._scheduled_tasks[task_id]
+ elif scheduled_task:
+ # A new task was likely rescheduled with the same ID.
+ log.debug(
+ f"{self.cog_name}: the scheduled task #{task_id} {id(scheduled_task)} "
+ f"and the done task {id(done_task)} differ."
+ )
+ elif not done_task.cancelled():
+ log.warning(
+ f"{self.cog_name}: task #{task_id} not found while handling task {id(done_task)}! "
+ f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)."
+ )
-def _silent_exception(future: asyncio.Future) -> None:
- """Suppress future's exception."""
- with contextlib.suppress(Exception):
- future.exception()
+ with contextlib.suppress(asyncio.CancelledError):
+ exception = done_task.exception()
+ # Log the exception if one exists.
+ if exception:
+ log.error(
+ f"{self.cog_name}: error in task #{task_id} {id(scheduled_task)}!",
+ exc_info=exception
+ )
diff --git a/config-default.yml b/config-default.yml
index f70fe3c34..5788d1e12 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -111,78 +111,135 @@ guild:
id: 267624335836053506
categories:
- python_help: 356013061213126657
+ python_help: 356013061213126657
channels:
- admins: &ADMINS 365960823622991872
- admin_spam: &ADMIN_SPAM 563594791770914816
- admins_voice: &ADMINS_VOICE 500734494840717332
- announcements: 354619224620138496
- attachment_log: &ATTCH_LOG 649243850006855680
- big_brother_logs: &BBLOGS 468507907357409333
- bot: &BOT_CMD 267659945086812160
- checkpoint_test: 422077681434099723
- defcon: &DEFCON 464469101889454091
- devcontrib: &DEV_CONTRIB 635950537262759947
- devcore: 411200599653351425
- devlog: &DEVLOG 622895325144940554
- devtest: &DEVTEST 414574275865870337
- esoteric: 470884583684964352
- help_0: 303906576991780866
- help_1: 303906556754395136
- help_2: 303906514266226689
- help_3: 439702951246692352
- help_4: 451312046647148554
- help_5: 454941769734422538
- help_6: 587375753306570782
- help_7: 587375768556797982
- helpers: &HELPERS 385474242440986624
- message_log: &MESSAGE_LOG 467752170159079424
- meta: 429409067623251969
- mod_spam: &MOD_SPAM 620607373828030464
- mods: &MODS 305126844661760000
- mod_alerts: 473092532147060736
- modlog: &MODLOG 282638479504965634
- off_topic_0: 291284109232308226
- off_topic_1: 463035241142026251
- off_topic_2: 463035268514185226
- organisation: &ORGANISATION 551789653284356126
- python: 267624335836053506
- reddit: 458224812528238616
- staff_lounge: &STAFF_LOUNGE 464905259261755392
- staff_voice: &STAFF_VOICE 412375055910043655
- talent_pool: &TALENT_POOL 534321732593647616
- userlog: 528976905546760203
- user_event_a: &USER_EVENT_A 592000283102674944
- verification: 352442727016693763
- voice_log: 640292421988646961
-
- staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON]
- ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG, *ADMINS_VOICE, *STAFF_VOICE, *ATTCH_LOG]
- reminder_whitelist: [*BOT_CMD, *DEV_CONTRIB]
+ announcements: 354619224620138496
+ user_event_announcements: &USER_EVENT_A 592000283102674944
+
+ # Development
+ dev_contrib: &DEV_CONTRIB 635950537262759947
+ dev_core: &DEV_CORE 411200599653351425
+ dev_log: &DEV_LOG 622895325144940554
+
+ # Discussion
+ meta: 429409067623251969
+ python_discussion: 267624335836053506
+
+ # Logs
+ attachment_log: &ATTACH_LOG 649243850006855680
+ message_log: &MESSAGE_LOG 467752170159079424
+ mod_log: &MOD_LOG 282638479504965634
+ user_log: 528976905546760203
+ voice_log: 640292421988646961
+
+ # Off-topic
+ off_topic_0: 291284109232308226
+ off_topic_1: 463035241142026251
+ off_topic_2: 463035268514185226
+
+ # Python Help
+ help_0: 303906576991780866
+ help_1: 303906556754395136
+ help_2: 303906514266226689
+ help_3: 439702951246692352
+ help_4: 451312046647148554
+ help_5: 454941769734422538
+ help_6: 587375753306570782
+ help_7: 587375768556797982
+
+ # Special
+ bot_commands: &BOT_CMD 267659945086812160
+ esoteric: 470884583684964352
+ reddit: 458224812528238616
+ verification: 352442727016693763
+
+ # Staff
+ admins: &ADMINS 365960823622991872
+ admin_spam: &ADMIN_SPAM 563594791770914816
+ defcon: &DEFCON 464469101889454091
+ helpers: &HELPERS 385474242440986624
+ mods: &MODS 305126844661760000
+ mod_alerts: &MOD_ALERTS 473092532147060736
+ mod_spam: &MOD_SPAM 620607373828030464
+ organisation: &ORGANISATION 551789653284356126
+ staff_lounge: &STAFF_LOUNGE 464905259261755392
+
+ # Voice
+ admins_voice: &ADMINS_VOICE 500734494840717332
+ staff_voice: &STAFF_VOICE 412375055910043655
+
+ # Watch
+ big_brother_logs: &BB_LOGS 468507907357409333
+ talent_pool: &TALENT_POOL 534321732593647616
+
+ staff_channels:
+ - *ADMINS
+ - *ADMIN_SPAM
+ - *DEFCON
+ - *HELPERS
+ - *MODS
+ - *MOD_SPAM
+ - *ORGANISATION
+
+ moderation_channels:
+ - *ADMINS
+ - *ADMIN_SPAM
+ - *MOD_ALERTS
+ - *MODS
+ - *MOD_SPAM
+
+ # Modlog cog ignores events which occur in these channels
+ modlog_blacklist:
+ - *ADMINS
+ - *ADMINS_VOICE
+ - *ATTACH_LOG
+ - *MESSAGE_LOG
+ - *MOD_LOG
+ - *STAFF_VOICE
+
+ reminder_whitelist:
+ - *BOT_CMD
+ - *DEV_CONTRIB
roles:
- admin: &ADMIN_ROLE 267628507062992896
- announcements: 463658397560995840
- champion: 430492892331769857
- contributor: 295488872404484098
- core_developer: 587606783669829632
- helpers: 267630620367257601
- jammer: 591786436651646989
- moderator: &MOD_ROLE 267629731250176001
- muted: &MUTED_ROLE 277914926603829249
- owner: &OWNER_ROLE 267627879762755584
- partners: 323426753857191936
- rockstars: &ROCKSTARS_ROLE 458226413825294336
- team_leader: 501324292341104650
- verified: 352427296948486144
+ announcements: 463658397560995840
+ contributors: 295488872404484098
+ muted: &MUTED_ROLE 277914926603829249
+ partners: 323426753857191936
+ python_community: &PY_COMMUNITY_ROLE 458226413825294336
+
+ # This is the Developers role on PyDis, here named verified for readability reasons
+ verified: 352427296948486144
+
+ # Staff
+ admins: &ADMINS_ROLE 267628507062992896
+ core_developers: 587606783669829632
+ helpers: &HELPERS_ROLE 267630620367257601
+ moderators: &MODS_ROLE 267629731250176001
+ owners: &OWNERS_ROLE 267627879762755584
+
+ # Code Jam
+ jammers: 591786436651646989
+ team_leaders: 501324292341104650
+
+ moderation_roles:
+ - *OWNERS_ROLE
+ - *ADMINS_ROLE
+ - *MODS_ROLE
+
+ staff_roles:
+ - *OWNERS_ROLE
+ - *ADMINS_ROLE
+ - *MODS_ROLE
+ - *HELPERS_ROLE
webhooks:
- talent_pool: 569145364800602132
- big_brother: 569133704568373283
- reddit: 635408384794951680
- duck_pond: 637821475327311927
- dev_log: 680501655111729222
+ talent_pool: 569145364800602132
+ big_brother: 569133704568373283
+ reddit: 635408384794951680
+ duck_pond: 637821475327311927
+ dev_log: 680501655111729222
filter:
@@ -227,6 +284,30 @@ filter:
domain_blacklist:
- pornhub.com
- liveleak.com
+ - grabify.link
+ - bmwforum.co
+ - leancoding.co
+ - spottyfly.com
+ - stopify.co
+ - yoütu.be
+ - discörd.com
+ - minecräft.com
+ - freegiftcards.co
+ - disçordapp.com
+ - fortnight.space
+ - fortnitechat.site
+ - joinmy.site
+ - curiouscat.club
+ - catsnthings.fun
+ - yourtube.site
+ - youtubeshort.watch
+ - catsnthing.com
+ - youtubeshort.pro
+ - canadianlumberjacks.online
+ - poweredbydialup.club
+ - poweredbydialup.online
+ - poweredbysecurity.org
+ - poweredbysecurity.online
word_watchlist:
- goo+ks*
@@ -260,20 +341,20 @@ filter:
# Censor doesn't apply to these
channel_whitelist:
- *ADMINS
- - *MODLOG
+ - *MOD_LOG
- *MESSAGE_LOG
- - *DEVLOG
- - *BBLOGS
+ - *DEV_LOG
+ - *BB_LOGS
- *STAFF_LOUNGE
- - *DEVTEST
- *TALENT_POOL
- *USER_EVENT_A
role_whitelist:
- - *ADMIN_ROLE
- - *MOD_ROLE
- - *OWNER_ROLE
- - *ROCKSTARS_ROLE
+ - *ADMINS_ROLE
+ - *MODS_ROLE
+ - *OWNERS_ROLE
+ - *HELPERS_ROLE
+ - *PY_COMMUNITY_ROLE
keys:
@@ -441,7 +522,20 @@ sync:
duck_pond:
threshold: 5
- custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA]
+ custom_emojis:
+ - *DUCKY_YELLOW
+ - *DUCKY_BLURPLE
+ - *DUCKY_CAMO
+ - *DUCKY_DEVIL
+ - *DUCKY_NINJA
+ - *DUCKY_REGAL
+ - *DUCKY_TUBE
+ - *DUCKY_HUNT
+ - *DUCKY_WIZARD
+ - *DUCKY_PARTY
+ - *DUCKY_ANGEL
+ - *DUCKY_MAUL
+ - *DUCKY_SANTA
config:
required_keys: ['bot.token']
diff --git a/tests/base.py b/tests/base.py
index 88693f382..42174e911 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -22,8 +22,13 @@ class _CaptureLogHandler(logging.Handler):
self.records.append(record)
-class LoggingTestCase(unittest.TestCase):
- """TestCase subclass that adds more logging assertion tools."""
+class LoggingTestsMixin:
+ """
+ A mixin that defines additional test methods for logging behavior.
+
+ This mixin relies on the availability of the `fail` attribute defined by the
+ test classes included in Python's unittest method to signal test failure.
+ """
@contextmanager
def assertNotLogs(self, logger=None, level=None, msg=None):
@@ -73,10 +78,9 @@ class LoggingTestCase(unittest.TestCase):
self.fail(msg)
-class CommandTestCase(unittest.TestCase):
+class CommandTestCase(unittest.IsolatedAsyncioTestCase):
"""TestCase with additional assertions that are useful for testing Discord commands."""
- @helpers.async_test
async def assertHasPermissionsCheck(
self,
cmd: commands.Command,
diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py
index e6a6f9688..fe0594efe 100644
--- a/tests/bot/cogs/sync/test_base.py
+++ b/tests/bot/cogs/sync/test_base.py
@@ -13,8 +13,8 @@ class TestSyncer(Syncer):
"""Syncer subclass with mocks for abstract methods for testing purposes."""
name = "test"
- _get_diff = helpers.AsyncMock()
- _sync = helpers.AsyncMock()
+ _get_diff = mock.AsyncMock()
+ _sync = mock.AsyncMock()
class SyncerBaseTests(unittest.TestCase):
@@ -29,7 +29,7 @@ class SyncerBaseTests(unittest.TestCase):
Syncer(self.bot)
-class SyncerSendPromptTests(unittest.TestCase):
+class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase):
"""Tests for sending the sync confirmation prompt."""
def setUp(self):
@@ -61,7 +61,6 @@ class SyncerSendPromptTests(unittest.TestCase):
return mock_channel, mock_message
- @helpers.async_test
async def test_send_prompt_edits_and_returns_message(self):
"""The given message should be edited to display the prompt and then should be returned."""
msg = helpers.MockMessage()
@@ -71,7 +70,6 @@ class SyncerSendPromptTests(unittest.TestCase):
self.assertIn("content", msg.edit.call_args[1])
self.assertEqual(ret_val, msg)
- @helpers.async_test
async def test_send_prompt_gets_dev_core_channel(self):
"""The dev-core channel should be retrieved if an extant message isn't given."""
subtests = (
@@ -84,9 +82,8 @@ class SyncerSendPromptTests(unittest.TestCase):
mock_()
await self.syncer._send_prompt()
- method.assert_called_once_with(constants.Channels.devcore)
+ method.assert_called_once_with(constants.Channels.dev_core)
- @helpers.async_test
async def test_send_prompt_returns_None_if_channel_fetch_fails(self):
"""None should be returned if there's an HTTPException when fetching the channel."""
self.bot.get_channel.return_value = None
@@ -96,7 +93,6 @@ class SyncerSendPromptTests(unittest.TestCase):
self.assertIsNone(ret_val)
- @helpers.async_test
async def test_send_prompt_sends_and_returns_new_message_if_not_given(self):
"""A new message mentioning core devs should be sent and returned if message isn't given."""
for mock_ in (self.mock_get_channel, self.mock_fetch_channel):
@@ -108,7 +104,6 @@ class SyncerSendPromptTests(unittest.TestCase):
self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0])
self.assertEqual(ret_val, mock_message)
- @helpers.async_test
async def test_send_prompt_adds_reactions(self):
"""The message should have reactions for confirmation added."""
extant_message = helpers.MockMessage()
@@ -129,13 +124,13 @@ class SyncerSendPromptTests(unittest.TestCase):
mock_message.add_reaction.assert_has_calls(calls)
-class SyncerConfirmationTests(unittest.TestCase):
+class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase):
"""Tests for waiting for a sync confirmation reaction on the prompt."""
def setUp(self):
self.bot = helpers.MockBot()
self.syncer = TestSyncer(self.bot)
- self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developer)
+ self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers)
@staticmethod
def get_message_reaction(emoji):
@@ -211,7 +206,6 @@ class SyncerConfirmationTests(unittest.TestCase):
ret_val = self.syncer._reaction_check(*args)
self.assertFalse(ret_val)
- @helpers.async_test
async def test_wait_for_confirmation(self):
"""The message should always be edited and only return True if the emoji is a check mark."""
subtests = (
@@ -251,14 +245,13 @@ class SyncerConfirmationTests(unittest.TestCase):
self.assertIs(actual_return, ret_val)
-class SyncerSyncTests(unittest.TestCase):
+class SyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for main function orchestrating the sync."""
def setUp(self):
self.bot = helpers.MockBot(user=helpers.MockMember(bot=True))
self.syncer = TestSyncer(self.bot)
- @helpers.async_test
async def test_sync_respects_confirmation_result(self):
"""The sync should abort if confirmation fails and continue if confirmed."""
mock_message = helpers.MockMessage()
@@ -274,7 +267,7 @@ class SyncerSyncTests(unittest.TestCase):
diff = _Diff({1, 2, 3}, {4, 5}, None)
self.syncer._get_diff.return_value = diff
- self.syncer._get_confirmation_result = helpers.AsyncMock(
+ self.syncer._get_confirmation_result = mock.AsyncMock(
return_value=(confirmed, message)
)
@@ -289,7 +282,6 @@ class SyncerSyncTests(unittest.TestCase):
else:
self.syncer._sync.assert_not_called()
- @helpers.async_test
async def test_sync_diff_size(self):
"""The diff size should be correctly calculated."""
subtests = (
@@ -303,7 +295,7 @@ class SyncerSyncTests(unittest.TestCase):
with self.subTest(size=size, diff=diff):
self.syncer._get_diff.reset_mock()
self.syncer._get_diff.return_value = diff
- self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+ self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
guild = helpers.MockGuild()
await self.syncer.sync(guild)
@@ -312,7 +304,6 @@ class SyncerSyncTests(unittest.TestCase):
self.syncer._get_confirmation_result.assert_called_once()
self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size)
- @helpers.async_test
async def test_sync_message_edited(self):
"""The message should be edited if one was sent, even if the sync has an API error."""
subtests = (
@@ -324,7 +315,7 @@ class SyncerSyncTests(unittest.TestCase):
for message, side_effect, should_edit in subtests:
with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit):
self.syncer._sync.side_effect = side_effect
- self.syncer._get_confirmation_result = helpers.AsyncMock(
+ self.syncer._get_confirmation_result = mock.AsyncMock(
return_value=(True, message)
)
@@ -335,7 +326,6 @@ class SyncerSyncTests(unittest.TestCase):
message.edit.assert_called_once()
self.assertIn("content", message.edit.call_args[1])
- @helpers.async_test
async def test_sync_confirmation_context_redirect(self):
"""If ctx is given, a new message should be sent and author should be ctx's author."""
mock_member = helpers.MockMember()
@@ -349,7 +339,10 @@ class SyncerSyncTests(unittest.TestCase):
if ctx is not None:
ctx.send.return_value = message
- self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None))
+ # Make sure `_get_diff` returns a MagicMock, not an AsyncMock
+ self.syncer._get_diff.return_value = mock.MagicMock()
+
+ self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None))
guild = helpers.MockGuild()
await self.syncer.sync(guild, ctx)
@@ -362,16 +355,15 @@ class SyncerSyncTests(unittest.TestCase):
self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message)
@mock.patch.object(constants.Sync, "max_diff", new=3)
- @helpers.async_test
async def test_confirmation_result_small_diff(self):
"""Should always return True and the given message if the diff size is too small."""
author = helpers.MockMember()
expected_message = helpers.MockMessage()
- for size in (3, 2):
+ for size in (3, 2): # pragma: no cover
with self.subTest(size=size):
- self.syncer._send_prompt = helpers.AsyncMock()
- self.syncer._wait_for_confirmation = helpers.AsyncMock()
+ self.syncer._send_prompt = mock.AsyncMock()
+ self.syncer._wait_for_confirmation = mock.AsyncMock()
coro = self.syncer._get_confirmation_result(size, author, expected_message)
result, actual_message = await coro
@@ -382,7 +374,6 @@ class SyncerSyncTests(unittest.TestCase):
self.syncer._wait_for_confirmation.assert_not_called()
@mock.patch.object(constants.Sync, "max_diff", new=3)
- @helpers.async_test
async def test_confirmation_result_large_diff(self):
"""Should return True if confirmed and False if _send_prompt fails or aborted."""
author = helpers.MockMember()
@@ -394,10 +385,10 @@ class SyncerSyncTests(unittest.TestCase):
(False, mock_message, False, "aborted"),
)
- for expected_result, expected_message, confirmed, msg in subtests:
+ for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover
with self.subTest(msg=msg):
- self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message)
- self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed)
+ self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message)
+ self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed)
coro = self.syncer._get_confirmation_result(4, author)
actual_result, actual_message = await coro
diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py
index 98c9afc0d..81398c61f 100644
--- a/tests/bot/cogs/sync/test_cog.py
+++ b/tests/bot/cogs/sync/test_cog.py
@@ -11,19 +11,7 @@ from tests import helpers
from tests.base import CommandTestCase
-class MockSyncer(helpers.CustomMockMixin, mock.MagicMock):
- """
- A MagicMock subclass to mock Syncer objects.
-
- Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer`
- instances. For more information, see the `MockGuild` docstring.
- """
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=Syncer, **kwargs)
-
-
-class SyncExtensionTests(unittest.TestCase):
+class SyncExtensionTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the sync extension."""
@staticmethod
@@ -34,22 +22,21 @@ class SyncExtensionTests(unittest.TestCase):
bot.add_cog.assert_called_once()
-class SyncCogTestCase(unittest.TestCase):
+class SyncCogTestCase(unittest.IsolatedAsyncioTestCase):
"""Base class for Sync cog tests. Sets up patches for syncers."""
def setUp(self):
self.bot = helpers.MockBot()
- # These patch the type. When the type is called, a MockSyncer instanced is returned.
- # MockSyncer is needed so that our custom AsyncMock is used.
- # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed.
self.role_syncer_patcher = mock.patch(
"bot.cogs.sync.syncers.RoleSyncer",
- new=mock.MagicMock(return_value=MockSyncer())
+ autospec=Syncer,
+ spec_set=True
)
self.user_syncer_patcher = mock.patch(
"bot.cogs.sync.syncers.UserSyncer",
- new=mock.MagicMock(return_value=MockSyncer())
+ autospec=Syncer,
+ spec_set=True
)
self.RoleSyncer = self.role_syncer_patcher.start()
self.UserSyncer = self.user_syncer_patcher.start()
@@ -72,13 +59,13 @@ class SyncCogTestCase(unittest.TestCase):
class SyncCogTests(SyncCogTestCase):
"""Tests for the Sync cog."""
- @mock.patch.object(sync.Sync, "sync_guild")
+ @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock)
def test_sync_cog_init(self, sync_guild):
"""Should instantiate syncers and run a sync for the guild."""
# Reset because a Sync cog was already instantiated in setUp.
self.RoleSyncer.reset_mock()
self.UserSyncer.reset_mock()
- self.bot.loop.create_task.reset_mock()
+ self.bot.loop.create_task = mock.MagicMock()
mock_sync_guild_coro = mock.MagicMock()
sync_guild.return_value = mock_sync_guild_coro
@@ -90,7 +77,6 @@ class SyncCogTests(SyncCogTestCase):
sync_guild.assert_called_once_with()
self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro)
- @helpers.async_test
async def test_sync_cog_sync_guild(self):
"""Roles and users should be synced only if a guild is successfully retrieved."""
for guild in (helpers.MockGuild(), None):
@@ -126,14 +112,12 @@ class SyncCogTests(SyncCogTestCase):
json=updated_information,
)
- @helpers.async_test
async def test_sync_cog_patch_user(self):
"""A PATCH request should be sent and 404 errors ignored."""
for side_effect in (None, self.response_error(404)):
with self.subTest(side_effect=side_effect):
await self.patch_user_helper(side_effect)
- @helpers.async_test
async def test_sync_cog_patch_user_non_404(self):
"""A PATCH request should be sent and the error raised if it's not a 404."""
with self.assertRaises(ResponseCodeError):
@@ -145,9 +129,8 @@ class SyncCogListenerTests(SyncCogTestCase):
def setUp(self):
super().setUp()
- self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user)
+ self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user)
- @helpers.async_test
async def test_sync_cog_on_guild_role_create(self):
"""A POST request should be sent with the new role's data."""
self.assertTrue(self.cog.on_guild_role_create.__cog_listener__)
@@ -164,7 +147,6 @@ class SyncCogListenerTests(SyncCogTestCase):
self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data)
- @helpers.async_test
async def test_sync_cog_on_guild_role_delete(self):
"""A DELETE request should be sent."""
self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__)
@@ -174,7 +156,6 @@ class SyncCogListenerTests(SyncCogTestCase):
self.bot.api_client.delete.assert_called_once_with("bot/roles/99")
- @helpers.async_test
async def test_sync_cog_on_guild_role_update(self):
"""A PUT request should be sent if the colour, name, permissions, or position changes."""
self.assertTrue(self.cog.on_guild_role_update.__cog_listener__)
@@ -212,7 +193,6 @@ class SyncCogListenerTests(SyncCogTestCase):
else:
self.bot.api_client.put.assert_not_called()
- @helpers.async_test
async def test_sync_cog_on_member_remove(self):
"""Member should patched to set in_guild as False."""
self.assertTrue(self.cog.on_member_remove.__cog_listener__)
@@ -225,7 +205,6 @@ class SyncCogListenerTests(SyncCogTestCase):
updated_information={"in_guild": False}
)
- @helpers.async_test
async def test_sync_cog_on_member_update_roles(self):
"""Members should be patched if their roles have changed."""
self.assertTrue(self.cog.on_member_update.__cog_listener__)
@@ -240,7 +219,6 @@ class SyncCogListenerTests(SyncCogTestCase):
data = {"roles": sorted(role.id for role in after_member.roles)}
self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data)
- @helpers.async_test
async def test_sync_cog_on_member_update_other(self):
"""Members should not be patched if other attributes have changed."""
self.assertTrue(self.cog.on_member_update.__cog_listener__)
@@ -262,7 +240,6 @@ class SyncCogListenerTests(SyncCogTestCase):
self.cog.patch_user.assert_not_called()
- @helpers.async_test
async def test_sync_cog_on_user_update(self):
"""A user should be patched only if the name, discriminator, or avatar changes."""
self.assertTrue(self.cog.on_user_update.__cog_listener__)
@@ -341,7 +318,6 @@ class SyncCogListenerTests(SyncCogTestCase):
return data
- @helpers.async_test
async def test_sync_cog_on_member_join(self):
"""Should PUT user's data or POST it if the user doesn't exist."""
for side_effect in (None, self.response_error(404)):
@@ -354,7 +330,6 @@ class SyncCogListenerTests(SyncCogTestCase):
else:
self.bot.api_client.post.assert_not_called()
- @helpers.async_test
async def test_sync_cog_on_member_join_non_404(self):
"""ResponseCodeError should be re-raised if status code isn't a 404."""
with self.assertRaises(ResponseCodeError):
@@ -366,7 +341,6 @@ class SyncCogListenerTests(SyncCogTestCase):
class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
"""Tests for the commands in the Sync cog."""
- @helpers.async_test
async def test_sync_roles_command(self):
"""sync() should be called on the RoleSyncer."""
ctx = helpers.MockContext()
@@ -374,7 +348,6 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx)
- @helpers.async_test
async def test_sync_users_command(self):
"""sync() should be called on the UserSyncer."""
ctx = helpers.MockContext()
@@ -382,7 +355,7 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx)
- def test_commands_require_admin(self):
+ async def test_commands_require_admin(self):
"""The sync commands should only run if the author has the administrator permission."""
cmds = (
self.cog.sync_group,
@@ -392,4 +365,4 @@ class SyncCogCommandTests(SyncCogTestCase, CommandTestCase):
for cmd in cmds:
with self.subTest(cmd=cmd):
- self.assertHasPermissionsCheck(cmd, {"administrator": True})
+ await self.assertHasPermissionsCheck(cmd, {"administrator": True})
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
index 14fb2577a..79eee98f4 100644
--- a/tests/bot/cogs/sync/test_roles.py
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -18,7 +18,7 @@ def fake_role(**kwargs):
return kwargs
-class RoleSyncerDiffTests(unittest.TestCase):
+class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between roles in the DB and roles in the Guild cache."""
def setUp(self):
@@ -39,7 +39,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
return guild
- @helpers.async_test
async def test_empty_diff_for_identical_roles(self):
"""No differences should be found if the roles in the guild and DB are identical."""
self.bot.api_client.get.return_value = [fake_role()]
@@ -50,7 +49,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_updated_roles(self):
"""Only updated roles should be added to the 'updated' set of the diff."""
updated_role = fake_role(id=41, name="new")
@@ -63,7 +61,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_roles(self):
"""Only new roles should be added to the 'created' set of the diff."""
new_role = fake_role(id=41, name="new")
@@ -76,7 +73,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_deleted_roles(self):
"""Only deleted roles should be added to the 'deleted' set of the diff."""
deleted_role = fake_role(id=61, name="deleted")
@@ -89,7 +85,6 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_updated_and_deleted_roles(self):
"""When roles are added, updated, and removed, all of them are returned properly."""
new = fake_role(id=41, name="new")
@@ -109,14 +104,13 @@ class RoleSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
-class RoleSyncerSyncTests(unittest.TestCase):
+class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the API requests that sync roles."""
def setUp(self):
self.bot = helpers.MockBot()
self.syncer = RoleSyncer(self.bot)
- @helpers.async_test
async def test_sync_created_roles(self):
"""Only POST requests should be made with the correct payload."""
roles = [fake_role(id=111), fake_role(id=222)]
@@ -132,7 +126,6 @@ class RoleSyncerSyncTests(unittest.TestCase):
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
- @helpers.async_test
async def test_sync_updated_roles(self):
"""Only PUT requests should be made with the correct payload."""
roles = [fake_role(id=111), fake_role(id=222)]
@@ -148,7 +141,6 @@ class RoleSyncerSyncTests(unittest.TestCase):
self.bot.api_client.post.assert_not_called()
self.bot.api_client.delete.assert_not_called()
- @helpers.async_test
async def test_sync_deleted_roles(self):
"""Only DELETE requests should be made with the correct payload."""
roles = [fake_role(id=111), fake_role(id=222)]
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
index 421bf6bb6..818883012 100644
--- a/tests/bot/cogs/sync/test_users.py
+++ b/tests/bot/cogs/sync/test_users.py
@@ -17,7 +17,7 @@ def fake_user(**kwargs):
return kwargs
-class UserSyncerDiffTests(unittest.TestCase):
+class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase):
"""Tests for determining differences between users in the DB and users in the Guild cache."""
def setUp(self):
@@ -42,7 +42,6 @@ class UserSyncerDiffTests(unittest.TestCase):
return guild
- @helpers.async_test
async def test_empty_diff_for_no_users(self):
"""When no users are given, an empty diff should be returned."""
guild = self.get_guild()
@@ -52,7 +51,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_empty_diff_for_identical_users(self):
"""No differences should be found if the users in the guild and DB are identical."""
self.bot.api_client.get.return_value = [fake_user()]
@@ -63,7 +61,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_updated_users(self):
"""Only updated users should be added to the 'updated' set of the diff."""
updated_user = fake_user(id=99, name="new")
@@ -76,7 +73,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_users(self):
"""Only new users should be added to the 'created' set of the diff."""
new_user = fake_user(id=99, name="new")
@@ -89,7 +85,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_sets_in_guild_false_for_leaving_users(self):
"""When a user leaves the guild, the `in_guild` flag is updated to `False`."""
leaving_user = fake_user(id=63, in_guild=False)
@@ -102,7 +97,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_diff_for_new_updated_and_leaving_users(self):
"""When users are added, updated, and removed, all of them are returned properly."""
new_user = fake_user(id=99, name="new")
@@ -117,7 +111,6 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
- @helpers.async_test
async def test_empty_diff_for_db_users_not_in_guild(self):
"""When the DB knows a user the guild doesn't, no difference is found."""
self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)]
@@ -129,14 +122,13 @@ class UserSyncerDiffTests(unittest.TestCase):
self.assertEqual(actual_diff, expected_diff)
-class UserSyncerSyncTests(unittest.TestCase):
+class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the API requests that sync users."""
def setUp(self):
self.bot = helpers.MockBot()
self.syncer = UserSyncer(self.bot)
- @helpers.async_test
async def test_sync_created_users(self):
"""Only POST requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
@@ -152,7 +144,6 @@ class UserSyncerSyncTests(unittest.TestCase):
self.bot.api_client.put.assert_not_called()
self.bot.api_client.delete.assert_not_called()
- @helpers.async_test
async def test_sync_updated_users(self):
"""Only PUT requests should be made with the correct payload."""
users = [fake_user(id=111), fake_user(id=222)]
diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py
index 5b0a3b8c3..7e6bfc748 100644
--- a/tests/bot/cogs/test_duck_pond.py
+++ b/tests/bot/cogs/test_duck_pond.py
@@ -2,7 +2,7 @@ import asyncio
import logging
import typing
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import discord
@@ -14,7 +14,7 @@ from tests import helpers
MODULE_PATH = "bot.cogs.duck_pond"
-class DuckPondTests(base.LoggingTestCase):
+class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase):
"""Tests for DuckPond functionality."""
@classmethod
@@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return):
self.assertEqual(expected_return, actual_return)
- @helpers.async_test
async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self):
"""The `has_green_checkmark` method should only return `True` if one is present."""
test_cases = (
@@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase):
nonstaffers = [helpers.MockMember() for _ in range(nonstaff)]
return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers)
- @helpers.async_test
async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self):
"""The `count_ducks` method should return the number of unique staffers who gave a duck."""
test_cases = (
@@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase):
with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count):
self.assertEqual(expected_count, actual_count)
- @helpers.async_test
async def test_relay_message_correctly_relays_content_and_attachments(self):
"""The `relay_message` method should correctly relay message content and attachments."""
send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook"
@@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase):
)
for message, expect_webhook_call, expect_attachment_call in test_values:
- with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook:
- with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments:
+ with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook:
+ with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments:
with self.subTest(clean_content=message.clean_content, attachments=message.attachments):
await self.cog.relay_message(message)
@@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase):
message.add_reaction.assert_called_once_with(self.checkmark_emoji)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase):
self.cog.webhook = helpers.MockAsyncWebhook()
log = logging.getLogger("bot.cogs.duck_pond")
- for side_effect in side_effects:
+ for side_effect in side_effects: # pragma: no cover
send_attachments.side_effect = side_effect
- with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook:
+ with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook:
with self.subTest(side_effect=type(side_effect).__name__):
with self.assertNotLogs(logger=log, level=logging.ERROR):
await self.cog.relay_message(message)
self.assertEqual(send_webhook.call_count, 2)
- @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock)
- @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock)
- @helpers.async_test
+ @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock)
+ @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock)
async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook):
"""The `relay_message` method should handle irretrievable attachments."""
message = helpers.MockMessage(clean_content="message", attachments=["attachment"])
@@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji.name = emoji_name
return payload
- @helpers.async_test
async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self):
"""The `on_raw_reaction_add` event handler should ignore irrelevant emojis."""
test_values = (
@@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase):
return channel, message, member, payload
- @helpers.async_test
async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self):
"""The `on_raw_reaction_add` event handler should return for bot users or non-staff members."""
channel_id = 1234
@@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase):
channel.fetch_message.reset_mock()
@patch(f"{MODULE_PATH}.DuckPond.is_staff")
- @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock)
+ @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock)
def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff):
"""The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot."""
channel_id = 31415926535
@@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase):
# Assert that we've made it past `self.is_staff`
is_staff.assert_called_once()
- @helpers.async_test
async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self):
"""The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold."""
test_cases = (
@@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase):
payload.emoji = self.duck_pond_emoji
for duck_count, should_relay in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_relay=should_relay):
await self.cog.on_raw_reaction_add(payload)
@@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase):
if should_relay:
relay_message.assert_called_once_with(message)
- @helpers.async_test
async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self):
"""The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks."""
checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji)
@@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase):
(constants.DuckPond.threshold + 1, True),
)
for duck_count, should_re_add_checkmark in test_cases:
- with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks:
+ with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks:
count_ducks.return_value = duck_count
with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark):
await self.cog.on_raw_reaction_remove(payload)
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index deae7ebad..5693d2946 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator)
+ cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators)
def setUp(self):
"""Sets up fresh objects for each test."""
@@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase):
"""Test if the `role_info` command correctly returns the `moderator_role`."""
self.ctx.guild.roles.append(self.moderator_role)
- self.cog.roles_info.can_run = helpers.AsyncMock()
+ self.cog.roles_info.can_run = unittest.mock.AsyncMock()
self.cog.roles_info.can_run.return_value = True
coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
@@ -72,7 +72,7 @@ class InformationCogTests(unittest.TestCase):
self.ctx.guild.roles.append([dummy_role, admin_role])
- self.cog.role_info.can_run = helpers.AsyncMock()
+ self.cog.role_info.can_run = unittest.mock.AsyncMock()
self.cog.role_info.can_run.return_value = True
coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
@@ -174,7 +174,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
self.member = helpers.MockMember(id=1234)
@@ -345,10 +345,10 @@ class UserEmbedTests(unittest.TestCase):
def setUp(self):
"""Common set-up steps done before for each test."""
self.bot = helpers.MockBot()
- self.bot.api_client.get = helpers.AsyncMock()
+ self.bot.api_client.get = unittest.mock.AsyncMock()
self.cog = information.Information(self.bot)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -360,7 +360,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Mr. Hemlock")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_nick_in_title_if_available(self):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -372,7 +372,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_ignores_everyone_role(self):
"""Created `!user` embeds should not contain mention of the @everyone-role."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
@@ -387,8 +387,8 @@ class UserEmbedTests(unittest.TestCase):
self.assertIn("&Admins", embed.description)
self.assertNotIn("&Everyone", embed.description)
- @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock)
- @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
"""The embed should contain expanded infractions and nomination info in mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50))
@@ -423,7 +423,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock)
def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
"""The embed should contain only basic infraction data outside of mod channels."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=100))
@@ -454,7 +454,7 @@ class UserEmbedTests(unittest.TestCase):
embed.description
)
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
"""The embed should be created with the colour of the top role, if a top role is available."""
ctx = helpers.MockContext()
@@ -467,7 +467,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
"""The embed should be created with a blurple colour if the user has no assigned roles."""
ctx = helpers.MockContext()
@@ -477,7 +477,7 @@ class UserEmbedTests(unittest.TestCase):
self.assertEqual(embed.colour, discord.Colour.blurple())
- @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value=""))
def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
"""The embed thumbnail should be set to the user's avatar in `png` format."""
ctx = helpers.MockContext()
@@ -521,7 +521,7 @@ class UserCommandTests(unittest.TestCase):
"""A regular user should not be able to use this command outside of bot-commands."""
constants.MODERATION_ROLES = [self.moderator_role.id]
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100))
@@ -529,11 +529,11 @@ class UserCommandTests(unittest.TestCase):
with self.assertRaises(InChannelCheckFailure, msg=msg):
asyncio.run(self.cog.user_info.callback(self.cog, ctx))
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
"""A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50))
@@ -542,11 +542,11 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):
"""A user should target itself with `!user` when a `user` argument was not provided."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50))
@@ -555,11 +555,11 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.author)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
"""Staff members should be able to bypass the bot-commands channel restriction."""
constants.STAFF_ROLES = [self.moderator_role.id]
- constants.Channels.bot = 50
+ constants.Channels.bot_commands = 50
ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200))
@@ -568,7 +568,7 @@ class UserCommandTests(unittest.TestCase):
create_embed.assert_called_once_with(ctx, self.moderator)
ctx.send.assert_called_once()
- @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock)
def test_moderators_can_target_another_member(self, create_embed, constants):
"""A moderator should be able to use `!user` targeting another user."""
constants.MODERATION_ROLES = [self.moderator_role.id]
diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py
new file mode 100644
index 000000000..9cd7f0154
--- /dev/null
+++ b/tests/bot/cogs/test_snekbox.py
@@ -0,0 +1,354 @@
+import asyncio
+import logging
+import unittest
+from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
+
+from bot.cogs import snekbox
+from bot.cogs.snekbox import Snekbox
+from bot.constants import URLs
+from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser
+
+
+class SnekboxTests(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ """Add mocked bot and cog to the instance."""
+ self.bot = MockBot()
+ self.cog = Snekbox(bot=self.bot)
+
+ async def test_post_eval(self):
+ """Post the eval code to the URLs.snekbox_eval_api endpoint."""
+ resp = MagicMock()
+ resp.json = AsyncMock(return_value="return")
+ self.bot.http_session.post().__aenter__.return_value = resp
+
+ self.assertEqual(await self.cog.post_eval("import random"), "return")
+ self.bot.http_session.post.assert_called_with(
+ URLs.snekbox_eval_api,
+ json={"input": "import random"},
+ raise_for_status=True
+ )
+ resp.json.assert_awaited_once()
+
+ async def test_upload_output_reject_too_long(self):
+ """Reject output longer than MAX_PASTE_LEN."""
+ result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1))
+ self.assertEqual(result, "too long to upload")
+
+ async def test_upload_output(self):
+ """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint."""
+ key = "MarkDiamond"
+ resp = MagicMock()
+ resp.json = AsyncMock(return_value={"key": key})
+ self.bot.http_session.post().__aenter__.return_value = resp
+
+ self.assertEqual(
+ await self.cog.upload_output("My awesome output"),
+ URLs.paste_service.format(key=key)
+ )
+ self.bot.http_session.post.assert_called_with(
+ URLs.paste_service.format(key="documents"),
+ data="My awesome output",
+ raise_for_status=True
+ )
+
+ async def test_upload_output_gracefully_fallback_if_exception_during_request(self):
+ """Output upload gracefully fallback if the upload fail."""
+ resp = MagicMock()
+ resp.json = AsyncMock(side_effect=Exception)
+ self.bot.http_session.post().__aenter__.return_value = resp
+
+ log = logging.getLogger("bot.cogs.snekbox")
+ with self.assertLogs(logger=log, level='ERROR'):
+ await self.cog.upload_output('My awesome output!')
+
+ async def test_upload_output_gracefully_fallback_if_no_key_in_response(self):
+ """Output upload gracefully fallback if there is no key entry in the response body."""
+ self.assertEqual((await self.cog.upload_output('My awesome output!')), None)
+
+ def test_prepare_input(self):
+ cases = (
+ ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'),
+ ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'),
+ ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'),
+ ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'),
+ )
+ for case, expected, testname in cases:
+ with self.subTest(msg=f'Extract code from {testname}.'):
+ self.assertEqual(self.cog.prepare_input(case), expected)
+
+ def test_get_results_message(self):
+ """Return error and message according to the eval result."""
+ cases = (
+ ('ERROR', None, ('Your eval job has failed', 'ERROR')),
+ ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')),
+ ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred'))
+ )
+ for stdout, returncode, expected in cases:
+ with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
+ actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode})
+ self.assertEqual(actual, expected)
+
+ @patch('bot.cogs.snekbox.Signals', side_effect=ValueError)
+ def test_get_results_message_invalid_signal(self, mock_Signals: Mock):
+ self.assertEqual(
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ ('Your eval job has completed with return code 127', '')
+ )
+
+ @patch('bot.cogs.snekbox.Signals')
+ def test_get_results_message_valid_signal(self, mock_Signals: Mock):
+ mock_Signals.return_value.name = 'SIGTEST'
+ self.assertEqual(
+ self.cog.get_results_message({'stdout': '', 'returncode': 127}),
+ ('Your eval job has completed with return code 127 (SIGTEST)', '')
+ )
+
+ def test_get_status_emoji(self):
+ """Return emoji according to the eval result."""
+ cases = (
+ (' ', -1, ':warning:'),
+ ('Hello world!', 0, ':white_check_mark:'),
+ ('Invalid beard size', -1, ':x:')
+ )
+ for stdout, returncode, expected in cases:
+ with self.subTest(stdout=stdout, returncode=returncode, expected=expected):
+ actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode})
+ self.assertEqual(actual, expected)
+
+ async def test_format_output(self):
+ """Test output formatting."""
+ self.cog.upload_output = AsyncMock(return_value='https://testificate.com/')
+
+ too_many_lines = (
+ '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n'
+ '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)'
+ )
+ too_long_too_many_lines = (
+ "\n".join(
+ f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1)
+ )[:1000] + "\n... (truncated - too long, too many lines)"
+ )
+
+ cases = (
+ ('', ('[No output]', None), 'No output'),
+ ('My awesome output', ('My awesome output', None), 'One line output'),
+ ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'),
+ ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'),
+ (
+ '\u202E\u202E\u202E',
+ ('Code block escape attempt detected; will not output result', None),
+ 'Detect RIGHT-TO-LEFT OVERRIDE'
+ ),
+ (
+ '\u200B\u200B\u200B',
+ ('Code block escape attempt detected; will not output result', None),
+ 'Detect ZERO WIDTH SPACE'
+ ),
+ ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'),
+ (
+ 'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd',
+ (too_many_lines, 'https://testificate.com/'),
+ '12 lines output'
+ ),
+ (
+ 'verylongbeard' * 100,
+ ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'),
+ '1300 characters output'
+ ),
+ (
+ ('verylongbeard' * 10 + '\n') * 15,
+ (too_long_too_many_lines, 'https://testificate.com/'),
+ '15 lines, 1965 characters output'
+ ),
+ )
+ for case, expected, testname in cases:
+ with self.subTest(msg=testname, case=case, expected=expected):
+ self.assertEqual(await self.cog.format_output(case), expected)
+
+ async def test_eval_command_evaluate_once(self):
+ """Test the eval command procedure."""
+ ctx = MockContext()
+ response = MockMessage()
+ self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.send_eval = AsyncMock(return_value=response)
+ self.cog.continue_eval = AsyncMock(return_value=None)
+
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ self.cog.prepare_input.assert_called_once_with('MyAwesomeCode')
+ self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.continue_eval.assert_called_once_with(ctx, response)
+
+ async def test_eval_command_evaluate_twice(self):
+ """Test the eval and re-eval command procedure."""
+ ctx = MockContext()
+ response = MockMessage()
+ self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode')
+ self.cog.send_eval = AsyncMock(return_value=response)
+ self.cog.continue_eval = AsyncMock()
+ self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None)
+
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2'))
+ self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode')
+ self.cog.continue_eval.assert_called_with(ctx, response)
+
+ async def test_eval_command_reject_two_eval_at_the_same_time(self):
+ """Test if the eval command rejects an eval if the author already have a running eval."""
+ ctx = MockContext()
+ ctx.author.id = 42
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ ctx.send = AsyncMock()
+ self.cog.jobs = (42,)
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!"
+ )
+
+ async def test_eval_command_call_help(self):
+ """Test if the eval command call the help command if no code is provided."""
+ ctx = MockContext()
+ ctx.invoke = AsyncMock()
+ await self.cog.eval_command.callback(self.cog, ctx=ctx, code='')
+ ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval")
+
+ async def test_send_eval(self):
+ """Test the send_eval function."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+
+ self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
+ self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ self.cog.format_output = AsyncMock(return_value=('[No output]', None))
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0})
+ self.cog.format_output.assert_called_once_with('')
+
+ async def test_send_eval_with_paste_link(self):
+ """Test the send_eval function with a too long output that generate a paste link."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+
+ self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 0', ''))
+ self.cog.get_status_emoji = MagicMock(return_value=':yay!:')
+ self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com'))
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :yay!: Return code 0.'
+ '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0})
+ self.cog.format_output.assert_called_once_with('Way too long beard')
+
+ async def test_send_eval_with_non_zero_eval(self):
+ """Test the send_eval function with a code returning a non-zero code."""
+ ctx = MockContext()
+ ctx.message = MockMessage()
+ ctx.send = AsyncMock()
+ ctx.author.mention = '@LemonLemonishBeard#0042'
+ self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval'))
+ self.cog.get_status_emoji = MagicMock(return_value=':nope!:')
+ self.cog.format_output = AsyncMock() # This function isn't called
+
+ await self.cog.send_eval(ctx, 'MyAwesomeCode')
+ ctx.send.assert_called_once_with(
+ '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```'
+ )
+ self.cog.post_eval.assert_called_once_with('MyAwesomeCode')
+ self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127})
+ self.cog.format_output.assert_not_called()
+
+ @patch("bot.cogs.snekbox.partial")
+ async def test_continue_eval_does_continue(self, partial_mock):
+ """Test that the continue_eval function does continue if required conditions are met."""
+ ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock()))
+ response = MockMessage(delete=AsyncMock())
+ new_msg = MockMessage(content='!e NewCode')
+ self.bot.wait_for.side_effect = ((None, new_msg), None)
+
+ actual = await self.cog.continue_eval(ctx, response)
+ self.assertEqual(actual, 'NewCode')
+ self.bot.wait_for.assert_has_awaits(
+ (
+ call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10),
+ call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10)
+ )
+ )
+ ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI)
+ ctx.message.clear_reactions.assert_called_once()
+ response.delete.assert_called_once()
+
+ async def test_continue_eval_does_not_continue(self):
+ ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock()))
+ self.bot.wait_for.side_effect = asyncio.TimeoutError
+
+ actual = await self.cog.continue_eval(ctx, MockMessage())
+ self.assertEqual(actual, None)
+ ctx.message.clear_reactions.assert_called_once()
+
+ def test_predicate_eval_message_edit(self):
+ """Test the predicate_eval_message_edit function."""
+ msg0 = MockMessage(id=1, content='abc')
+ msg1 = MockMessage(id=2, content='abcdef')
+ msg2 = MockMessage(id=1, content='abcdef')
+
+ cases = (
+ (msg0, msg0, False, 'same ID, same content'),
+ (msg0, msg1, False, 'different ID, different content'),
+ (msg0, msg2, True, 'same ID, different content')
+ )
+ for ctx_msg, new_msg, expected, testname in cases:
+ with self.subTest(msg=f'Messages with {testname} return {expected}'):
+ ctx = MockContext(message=ctx_msg)
+ actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg)
+ self.assertEqual(actual, expected)
+
+ def test_predicate_eval_emoji_reaction(self):
+ """Test the predicate_eval_emoji_reaction function."""
+ valid_reaction = MockReaction(message=MockMessage(id=1))
+ valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI
+ valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2))
+ valid_user = MockUser(id=2)
+
+ invalid_reaction_id = MockReaction(message=MockMessage(id=42))
+ invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI
+ invalid_user_id = MockUser(id=42)
+ invalid_reaction_str = MockReaction(message=MockMessage(id=1))
+ invalid_reaction_str.__str__.return_value = ':longbeard:'
+
+ cases = (
+ (invalid_reaction_id, valid_user, False, 'invalid reaction ID'),
+ (valid_reaction, invalid_user_id, False, 'invalid user ID'),
+ (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'),
+ (valid_reaction, valid_user, True, 'matching attributes')
+ )
+ for reaction, user, expected, testname in cases:
+ with self.subTest(msg=f'Test with {testname} and expected return {expected}'):
+ actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user)
+ self.assertEqual(actual, expected)
+
+
+class SnekboxSetupTests(unittest.TestCase):
+ """Tests setup of the `Snekbox` cog."""
+
+ def test_setup(self):
+ """Setup of the extension should call add_cog."""
+ bot = MockBot()
+ snekbox.setup(bot)
+ bot.add_cog.assert_called_once()
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
index a54b839d7..33d1ec170 100644
--- a/tests/bot/cogs/test_token_remover.py
+++ b/tests/bot/cogs/test_token_remover.py
@@ -1,7 +1,7 @@
import asyncio
import logging
import unittest
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
from discord import Colour
@@ -11,7 +11,7 @@ from bot.cogs.token_remover import (
setup as setup_cog,
)
from bot.constants import Channels, Colours, Event, Icons
-from tests.helpers import AsyncMock, MockBot, MockMessage
+from tests.helpers import MockBot, MockMessage
class TokenRemoverTests(unittest.TestCase):
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index 36c986fe1..0d570f5a3 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -12,7 +12,7 @@ class DisallowedCase(NamedTuple):
n_violations: int
-class RuleTest(unittest.TestCase, metaclass=ABCMeta):
+class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta):
"""
Abstract class for antispam rule test cases.
@@ -68,9 +68,9 @@ class RuleTest(unittest.TestCase, metaclass=ABCMeta):
@abstractmethod
def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
"""Give expected relevant messages for `case`."""
- raise NotImplementedError
+ raise NotImplementedError # pragma: no cover
@abstractmethod
def get_report(self, case: DisallowedCase) -> str:
"""Give expected error report for `case`."""
- raise NotImplementedError
+ raise NotImplementedError # pragma: no cover
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index e54b4b5b8..d7e779221 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import attachments
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_attachments: int) -> MockMessage:
@@ -17,7 +17,6 @@ class AttachmentRuleTests(RuleTest):
self.apply = attachments.apply
self.config = {"max": 5, "interval": 10}
- @async_test
async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
@@ -28,7 +27,6 @@ class AttachmentRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
index 72f0be0c7..03682966b 100644
--- a/tests/bot/rules/test_burst.py
+++ b/tests/bot/rules/test_burst.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import burst
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str) -> MockMessage:
@@ -21,7 +21,6 @@ class BurstRuleTests(RuleTest):
self.apply = burst.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -31,7 +30,6 @@ class BurstRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
index 47367a5f8..3275143d5 100644
--- a/tests/bot/rules/test_burst_shared.py
+++ b/tests/bot/rules/test_burst_shared.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import burst_shared
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str) -> MockMessage:
@@ -21,7 +21,6 @@ class BurstSharedRuleTests(RuleTest):
self.apply = burst_shared.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""
Cases that do not violate the rule.
@@ -34,7 +33,6 @@ class BurstSharedRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the amount of messages exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
index 7cc36f49e..f1e3c76a7 100644
--- a/tests/bot/rules/test_chars.py
+++ b/tests/bot/rules/test_chars.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import chars
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, n_chars: int) -> MockMessage:
@@ -20,7 +20,6 @@ class CharsRuleTests(RuleTest):
"interval": 10,
}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of chars within limit."""
cases = (
@@ -31,7 +30,6 @@ class CharsRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases where the total amount of chars exceeds the limit, triggering the rule."""
cases = (
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
index 0239b0b00..9a72723e2 100644
--- a/tests/bot/rules/test_discord_emojis.py
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import discord_emojis
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
@@ -19,7 +19,6 @@ class DiscordEmojisRuleTests(RuleTest):
self.apply = discord_emojis.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of discord emojis within limit."""
cases = (
@@ -29,7 +28,6 @@ class DiscordEmojisRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of discord emojis."""
cases = (
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
index 59e0fb6ef..9bd886a77 100644
--- a/tests/bot/rules/test_duplicates.py
+++ b/tests/bot/rules/test_duplicates.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import duplicates
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, content: str) -> MockMessage:
@@ -17,7 +17,6 @@ class DuplicatesRuleTests(RuleTest):
self.apply = duplicates.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -28,7 +27,6 @@ class DuplicatesRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with too many duplicate messages from the same author."""
cases = (
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index 3c3f90e5f..b091bd9d7 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import links
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_links: int) -> MockMessage:
@@ -21,7 +21,6 @@ class LinksTests(RuleTest):
"interval": 10
}
- @async_test
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
@@ -34,7 +33,6 @@ class LinksTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index ebcdabac6..6444532f2 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, total_mentions: int) -> MockMessage:
@@ -20,7 +20,6 @@ class TestMentions(RuleTest):
"interval": 10,
}
- @async_test
async def test_mentions_within_limit(self):
"""Messages with an allowed amount of mentions."""
cases = (
@@ -32,7 +31,6 @@ class TestMentions(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
index d61c4609d..e35377773 100644
--- a/tests/bot/rules/test_newlines.py
+++ b/tests/bot/rules/test_newlines.py
@@ -2,7 +2,7 @@ from typing import Iterable, List
from bot.rules import newlines
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
@@ -29,7 +29,6 @@ class TotalNewlinesRuleTests(RuleTest):
"interval": 10,
}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases which do not violate the rule."""
cases = (
@@ -41,7 +40,6 @@ class TotalNewlinesRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_total(self):
"""Cases which violate the rule by having too many newlines in total."""
cases = (
@@ -79,7 +77,6 @@ class GroupNewlinesRuleTests(RuleTest):
self.apply = newlines.apply
self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
- @async_test
async def test_disallows_messages_consecutive(self):
"""Cases which violate the rule due to having too many consecutive newlines."""
cases = (
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
index b339cccf7..26c05d527 100644
--- a/tests/bot/rules/test_role_mentions.py
+++ b/tests/bot/rules/test_role_mentions.py
@@ -2,7 +2,7 @@ from typing import Iterable
from bot.rules import role_mentions
from tests.bot.rules import DisallowedCase, RuleTest
-from tests.helpers import MockMessage, async_test
+from tests.helpers import MockMessage
def make_msg(author: str, n_mentions: int) -> MockMessage:
@@ -17,7 +17,6 @@ class RoleMentionsRuleTests(RuleTest):
self.apply = role_mentions.apply
self.config = {"max": 2, "interval": 10}
- @async_test
async def test_allows_messages_within_limit(self):
"""Cases with a total amount of role mentions within limit."""
cases = (
@@ -27,7 +26,6 @@ class RoleMentionsRuleTests(RuleTest):
await self.run_allowed(cases)
- @async_test
async def test_disallows_messages_beyond_limit(self):
"""Cases with more than the allowed amount of role mentions."""
cases = (
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index bdfcc73e4..99e942813 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -2,10 +2,9 @@ import unittest
from unittest.mock import MagicMock
from bot import api
-from tests.helpers import async_test
-class APIClientTests(unittest.TestCase):
+class APIClientTests(unittest.IsolatedAsyncioTestCase):
"""Tests for the bot's API client."""
@classmethod
@@ -18,7 +17,6 @@ class APIClientTests(unittest.TestCase):
"""The event loop should not be running by default."""
self.assertFalse(api.loop_is_running())
- @async_test
async def test_loop_is_running_in_async_context(self):
"""The event loop should be running in an async context."""
self.assertTrue(api.loop_is_running())
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index b2b78d9dd..1e5ca62ae 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -68,7 +68,7 @@ class ConverterTests(unittest.TestCase):
('👋', "Don't be ridiculous, you can't use that character!"),
('', "Tag names should not be empty, or filled with whitespace."),
(' ', "Tag names should not be empty, or filled with whitespace."),
- ('42', "Tag names can't be numbers."),
+ ('42', "Tag names must contain at least one letter."),
('x' * 128, "Are you insane? That's way too long!"),
)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
index 58ae2a81a..d7bcc3ba6 100644
--- a/tests/bot/test_utils.py
+++ b/tests/bot/test_utils.py
@@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase):
instance = utils.CaseInsensitiveDict()
instance.update({'FOO': 'bar'})
self.assertEqual(instance['foo'], 'bar')
-
-
-class ChunkTests(unittest.TestCase):
- """Tests the `chunk` method."""
-
- def test_empty_chunking(self):
- """Tests chunking on an empty iterable."""
- generator = utils.chunks(iterable=[], size=5)
- self.assertEqual(list(generator), [])
-
- def test_list_chunking(self):
- """Tests chunking a non-empty list."""
- iterable = [1, 2, 3, 4, 5]
- generator = utils.chunks(iterable=iterable, size=2)
- self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])
diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py
index 69f35f2f5..694d3a40f 100644
--- a/tests/bot/utils/test_time.py
+++ b/tests/bot/utils/test_time.py
@@ -1,12 +1,11 @@
import asyncio
import unittest
from datetime import datetime, timezone
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
from dateutil.relativedelta import relativedelta
from bot.utils import time
-from tests.helpers import AsyncMock
class TimeTests(unittest.TestCase):
@@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase):
for max_units in test_cases:
with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error:
time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
- self.assertEqual(str(error), 'max_units must be positive')
+ self.assertEqual(str(error.exception), 'max_units must be positive')
def test_parse_rfc1123(self):
"""Testing parse_rfc1123."""
diff --git a/tests/helpers.py b/tests/helpers.py
index 9d9dd5da6..8e13f0f28 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,13 +1,10 @@
from __future__ import annotations
-import asyncio
import collections
-import functools
-import inspect
import itertools
import logging
import unittest.mock
-from typing import Any, Iterable, Optional
+from typing import Iterable, Optional
import discord
from discord.ext.commands import Context
@@ -26,21 +23,6 @@ for logger in logging.Logger.manager.loggerDict.values():
logger.setLevel(logging.CRITICAL)
-def async_test(wrapped):
- """
- Run a test case via asyncio.
- Example:
- >>> @async_test
- ... async def lemon_wins():
- ... assert True
- """
-
- @functools.wraps(wrapped)
- def wrapper(*args, **kwargs):
- return asyncio.run(wrapped(*args, **kwargs))
- return wrapper
-
-
class HashableMixin(discord.mixins.EqualityComparable):
"""
Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.
@@ -69,24 +51,31 @@ class CustomMockMixin:
"""
Provides common functionality for our custom Mock types.
- The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
- function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
- of making sure child mocks are instantiated with the correct class. By default, the mock of the
- children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
- `child_mock_type` on the custom mock inheriting from this mixin.
+ The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock
+ object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the
+ class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional
+ attribute names that should also mocked with an AsyncMock instead of a regular MagicMock/Mock. The
+ class method `spec_set` can be overwritten with the object that should be uses as the specification
+ for the mock.
+
+ Mock/MagicMock subclasses that use this mixin only need to define `__init__` method if they need to
+ implement custom behavior.
"""
child_mock_type = unittest.mock.MagicMock
discord_id = itertools.count(0)
+ spec_set = None
+ additional_spec_asyncs = None
- def __init__(self, spec_set: Any = None, **kwargs):
+ def __init__(self, **kwargs):
name = kwargs.pop('name', None) # `name` has special meaning for Mock classes, so we need to set it manually.
- super().__init__(spec_set=spec_set, **kwargs)
+ super().__init__(spec_set=self.spec_set, **kwargs)
+
+ if self.additional_spec_asyncs:
+ self._spec_asyncs.extend(self.additional_spec_asyncs)
if name:
self.name = name
- if spec_set:
- self._extract_coroutine_methods_from_spec_instance(spec_set)
def _get_child_mock(self, **kw):
"""
@@ -100,7 +89,16 @@ class CustomMockMixin:
This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
"""
- klass = self.child_mock_type
+ _new_name = kw.get("_new_name")
+ if _new_name in self.__dict__['_spec_asyncs']:
+ return unittest.mock.AsyncMock(**kw)
+
+ _type = type(self)
+ if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics:
+ # Any asynchronous magic becomes an AsyncMock
+ klass = unittest.mock.AsyncMock
+ else:
+ klass = self.child_mock_type
if self._mock_sealed:
attribute = "." + kw["name"] if "name" in kw else "()"
@@ -109,95 +107,6 @@ class CustomMockMixin:
return klass(**kw)
- def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
- """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
- for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
- setattr(self, name, AsyncMock())
-
-
-# TODO: Remove me in Python 3.8
-class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
- """
- A MagicMock subclass to mock async callables.
-
- Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
- features; this stand-in only overwrites the `__call__` method to an async version.
- """
-
- async def __call__(self, *args, **kwargs):
- return super().__call__(*args, **kwargs)
-
-
-class AsyncIteratorMock:
- """
- A class to mock asynchronous iterators.
-
- This allows async for, which is used in certain Discord.py objects. For example,
- an async iterator is returned by the Reaction.users() method.
- """
-
- def __init__(self, iterable: Iterable = None):
- if iterable is None:
- iterable = []
-
- self.iter = iter(iterable)
- self.iterable = iterable
-
- self.call_count = 0
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- try:
- return next(self.iter)
- except StopIteration:
- raise StopAsyncIteration
-
- def __call__(self):
- """
- Keeps track of the number of times an instance has been called.
-
- This is useful, since it typically shows that the iterator has actually been used somewhere after we have
- instantiated the mock for an attribute that normally returns an iterator when called.
- """
- self.call_count += 1
- return self
-
- @property
- def return_value(self):
- """Makes `self.iterable` accessible as self.return_value."""
- return self.iterable
-
- @return_value.setter
- def return_value(self, iterable):
- """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`."""
- self.iter = iter(iterable)
- self.iterable = iterable
-
- def assert_called(self):
- """Asserts if the AsyncIteratorMock instance has been called at least once."""
- if self.call_count == 0:
- raise AssertionError("Expected AsyncIteratorMock to have been called.")
-
- def assert_called_once(self):
- """Asserts if the AsyncIteratorMock instance has been called exactly once."""
- if self.call_count != 1:
- raise AssertionError(
- f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times."
- )
-
- def assert_not_called(self):
- """Asserts if the AsyncIteratorMock instance has not been called."""
- if self.call_count != 0:
- raise AssertionError(
- f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times."
- )
-
- def reset_mock(self):
- """Resets the call count, but not the return value or iterator."""
- self.call_count = 0
-
# Create a guild instance to get a realistic Mock of `discord.Guild`
guild_data = {
@@ -248,9 +157,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
For more info, see the `Mocking` section in `tests/README.md`.
"""
+ spec_set = guild_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'members': []}
- super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -269,6 +180,8 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.Role` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = role_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {
'id': next(self.discord_id),
@@ -277,7 +190,7 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
'colour': discord.Colour(0xdeadbf),
'permissions': discord.Permissions(),
}
- super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if isinstance(self.colour, int):
self.colour = discord.Colour(self.colour)
@@ -306,9 +219,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin
Instances of this class will follow the specifications of `discord.Member` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = member_instance
+
def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None:
default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.roles = [MockRole(name="@everyone", position=1, id=0)]
if roles:
@@ -329,9 +244,11 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
Instances of this class will follow the specifications of `discord.User` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = user_instance
+
def __init__(self, **kwargs) -> None:
default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False}
- super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"@{self.name}"
@@ -344,9 +261,7 @@ class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `bot.api.APIClient` instances.
For more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=APIClient, **kwargs)
+ spec_set = APIClient
# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
@@ -362,16 +277,13 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
For more information, see the `MockGuild` docstring.
"""
+ spec_set = bot_instance
+ additional_spec_asyncs = ("wait_for",)
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=bot_instance, **kwargs)
+ super().__init__(**kwargs)
self.api_client = MockAPIClient()
- # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
- # and should therefore be awaited. (The documentation calls it a coroutine as well, which
- # is technically incorrect, since it's a regular def.)
- self.wait_for = AsyncMock()
-
# Since calling `create_task` on our MockBot does not actually schedule the coroutine object
# as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object
# to prevent "has not been awaited"-warnings.
@@ -401,10 +313,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
Instances of this class will follow the specifications of `discord.TextChannel` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = channel_instance
def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()}
- super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
if 'mention' not in kwargs:
self.mention = f"#{self.name}"
@@ -443,9 +356,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.ext.commands.Context`
instances. For more information, see the `MockGuild` docstring.
"""
+ spec_set = context_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=context_instance, **kwargs)
+ super().__init__(**kwargs)
self.bot = kwargs.get('bot', MockBot())
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
@@ -462,8 +376,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Attachment` instances. For
more information, see the `MockGuild` docstring.
"""
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=attachment_instance, **kwargs)
+ spec_set = attachment_instance
class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
@@ -473,10 +386,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Message` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = message_instance
def __init__(self, **kwargs) -> None:
default_kwargs = {'attachments': []}
- super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs))
+ super().__init__(**collections.ChainMap(kwargs, default_kwargs))
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
@@ -492,9 +406,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Emoji` instances. For more
information, see the `MockGuild` docstring.
"""
+ spec_set = emoji_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=emoji_instance, **kwargs)
+ super().__init__(**kwargs)
self.guild = kwargs.get('guild', MockGuild())
@@ -508,9 +423,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=partial_emoji_instance, **kwargs)
+ spec_set = partial_emoji_instance
reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
@@ -523,12 +436,18 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Reaction` instances. For
more information, see the `MockGuild` docstring.
"""
+ spec_set = reaction_instance
def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=reaction_instance, **kwargs)
+ _users = kwargs.pop("users", [])
+ super().__init__(**kwargs)
self.emoji = kwargs.get('emoji', MockEmoji())
self.message = kwargs.get('message', MockMessage())
- self.users = AsyncIteratorMock(kwargs.get('users', []))
+
+ user_iterator = unittest.mock.AsyncMock()
+ user_iterator.__aiter__.return_value = _users
+ self.users.return_value = user_iterator
+
self.__str__.return_value = str(self.emoji)
@@ -542,13 +461,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock):
Instances of this class will follow the specifications of `discord.Webhook` instances. For
more information, see the `MockGuild` docstring.
"""
-
- def __init__(self, **kwargs) -> None:
- super().__init__(spec_set=webhook_instance, **kwargs)
-
- # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined
- # as coroutines. That's why we need to set the methods manually.
- self.send = AsyncMock()
- self.edit = AsyncMock()
- self.delete = AsyncMock()
- self.execute = AsyncMock()
+ spec_set = webhook_instance
+ additional_spec_asyncs = ("send", "edit", "delete", "execute")
diff --git a/tests/test_base.py b/tests/test_base.py
index a16e2af8f..a7db4bf3e 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -3,7 +3,11 @@ import unittest
import unittest.mock
-from tests.base import LoggingTestCase, _CaptureLogHandler
+from tests.base import LoggingTestsMixin, _CaptureLogHandler
+
+
+class LoggingTestCase(LoggingTestsMixin, unittest.TestCase):
+ pass
class LoggingTestCaseTests(unittest.TestCase):
@@ -18,24 +22,14 @@ class LoggingTestCaseTests(unittest.TestCase):
try:
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
pass
- except AssertionError:
+ except AssertionError: # pragma: no cover
self.fail("`self.assertNotLogs` raised an AssertionError when it should not!")
- @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs")
- def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs):
- """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly."""
- assertNotLogs.return_value = iter([None])
- assertNotLogs.side_effect = AssertionError
-
- message = "`self.assertNotLogs` raised an AssertionError when it should not!"
- with self.assertRaises(AssertionError, msg=message):
- self.test_assert_not_logs_does_not_raise_with_no_logs()
-
def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self):
"""Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted."""
msg_regex = (
r"1 logs of DEBUG or higher were triggered on root:\n"
- r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">'
+ r'<LogRecord: tests\.test_base, [\d]+, .+[/\\]tests[/\\]test_base\.py, [\d]+, "Log!">'
)
with self.assertRaisesRegex(AssertionError, msg_regex):
with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 7894e104a..81285e009 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
import unittest
import unittest.mock
@@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase):
with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"):
asyncio.run(coroutine_object)
+ def test_user_mock_uses_explicitly_passed_mention_attribute(self):
+ """MockUser should use an explicitly passed value for user.mention."""
+ user = helpers.MockUser(mention="hello")
+ self.assertEqual(user.mention, "hello")
+
class MockObjectTests(unittest.TestCase):
"""Tests the mock objects and mixins we've defined."""
@@ -341,65 +345,10 @@ class MockObjectTests(unittest.TestCase):
attribute = getattr(mock, valid_attribute)
self.assertTrue(isinstance(attribute, mock_type.child_mock_type))
- def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self):
- """Test if all coroutine functions are extracted, but not regular methods or attributes."""
- class CoroutineDonor:
- def __init__(self):
- self.some_attribute = 'alpha'
-
- async def first_coroutine():
- """This coroutine function should be extracted."""
-
- async def second_coroutine():
- """This coroutine function should be extracted."""
-
- def regular_method():
- """This regular function should not be extracted."""
-
- class Receiver:
+ def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self):
+ """The CustomMockMixin should mock async magic methods with an AsyncMock."""
+ class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):
pass
- donor = CoroutineDonor()
- receiver = Receiver()
-
- helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor)
-
- self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock)
- self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock)
- self.assertFalse(hasattr(receiver, 'regular_method'))
- self.assertFalse(hasattr(receiver, 'some_attribute'))
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_with_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- spec_set = "pydis"
-
- helpers.CustomMockMixin(spec_set=spec_set)
-
- extract_method_mock.assert_called_once_with(spec_set)
-
- @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
- @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
- def test_custom_mock_mixin_init_without_spec(self, extract_method_mock):
- """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
- helpers.CustomMockMixin()
-
- extract_method_mock.assert_not_called()
-
- def test_async_mock_provides_coroutine_for_dunder_call(self):
- """Test if AsyncMock objects have a coroutine for their __call__ method."""
- async_mock = helpers.AsyncMock()
- self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__))
-
- coroutine = async_mock()
- self.assertTrue(inspect.iscoroutine(coroutine))
- self.assertIsNotNone(asyncio.run(coroutine))
-
- def test_async_test_decorator_allows_synchronous_call_to_async_def(self):
- """Test if the `async_test` decorator allows an `async def` to be called synchronously."""
- @helpers.async_test
- async def kosayoda():
- return "return value"
-
- self.assertEqual(kosayoda(), "return value")
+ mock = MyMock()
+ self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock)
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
deleted file mode 100644
index 4baa6395c..000000000
--- a/tests/utils/test_time.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import asyncio
-from datetime import datetime, timezone
-from unittest.mock import patch
-
-import pytest
-from dateutil.relativedelta import relativedelta
-
-from bot.utils import time
-from tests.helpers import AsyncMock
-
-
- ('delta', 'precision', 'max_units', 'expected'),
- (
- (relativedelta(days=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
- (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
- (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
-
- # Does not abort for unknown units, as the unit name is checked
- # against the attribute of the relativedelta instance.
- (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'),
-
- # Very high maximum units, but it only ever iterates over
- # each value the relativedelta might have.
- (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'),
- )
-)
-def test_humanize_delta(
- delta: relativedelta,
- precision: str,
- max_units: int,
- expected: str
-):
- assert time.humanize_delta(delta, precision, max_units) == expected
-
-
[email protected]('max_units', (-1, 0))
-def test_humanize_delta_raises_for_invalid_max_units(max_units: int):
- with pytest.raises(ValueError, match='max_units must be positive'):
- time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
-
-
- ('stamp', 'expected'),
- (
- ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)),
- )
-)
-def test_parse_rfc1123(stamp: str, expected: str):
- assert time.parse_rfc1123(stamp) == expected
-
-
-@patch('asyncio.sleep', new_callable=AsyncMock)
-def test_wait_until(sleep_patch):
- start = datetime(2019, 1, 1, 0, 0)
- then = datetime(2019, 1, 1, 0, 10)
-
- # No return value
- assert asyncio.run(time.wait_until(then, start)) is None
-
- sleep_patch.assert_called_once_with(10 * 60)