80 files changed, 4239 insertions, 2207 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 860357868..876d32b15 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,26 @@
 repos:
-- repo: local
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v2.5.0
   hooks:
-  - id: flake8
+  - id: check-merge-conflict
+  - id: check-toml
+  - id: check-yaml
+    args: [--unsafe] # Required due to custom constructors (e.g. !ENV)
+  - id: end-of-file-fixer
+  - id: mixed-line-ending
+    args: [--fix=lf]
+  - id: trailing-whitespace
+    args: [--markdown-linebreak-ext=md]
+- repo: https://github.com/pre-commit/pygrep-hooks
+  rev: v1.5.1
+  hooks:
+  - id: python-check-blanket-noqa
+- repo: local
+  hooks:
+  - id: flake8
     name: Flake8
     description: This hook runs flake8 within our project's pipenv environment.
-    entry: pipenv run lint
+    entry: pipenv run flake8
     language: python
     types: [python]
-    require_serial: true
\ No newline at end of file
+    require_serial: true
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 39f76c7b4..61d11f844 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -43,7 +43,7 @@ To provide a standalone development environment for this project, docker compose
 When pulling down changes from GitHub, remember to sync your environment using `pipenv sync --dev` to ensure you're using the most up-to-date versions the project's dependencies.
 
 ### Type Hinting
-[PEP 484](https://www.python.org/dev/peps/pep-0484/) formally specifies type hints for Python functions, added to the Python Standard Library in version 3.5. Type hints are recognized by most modern code editing tools and provide useful insight into both the input and output types of a function, preventing the user from having to go through the codebase to determine these types.
+[PEP 484](https://www.python.org/dev/peps/pep-0484/) formally specifies type hints for Python functions, added to the Python Standard Library in version 3.5. Type hints are recognized by most modern code editing tools and provide useful insight into both the input and output types of a function, preventing the user from having to go through the codebase to determine these types. For example:
diff --git a/Dockerfile b/Dockerfile
index 271c25050..06a538b2a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.7-slim
+FROM python:3.8-slim
 
 # Set pip to have cleaner logs and no saved cache
 ENV PIP_NO_CACHE_DIR=false \
@@ -9,12 +9,15 @@ ENV PIP_NO_CACHE_DIR=false \
 # Install pipenv
 RUN pip install -U pipenv
 
-# Copy project files into working directory
+# Create the working directory
 WORKDIR /bot
-COPY . .
 
 # Install project dependencies
+COPY Pipfile* ./
 RUN pipenv install --system --deploy
 
+# Copy the source code in last to optimize rebuilding the image
+COPY . .
+
 ENTRYPOINT ["python3"]
 CMD ["-m", "bot"]
@@ -4,9 +4,8 @@ verify_ssl = true
 name = "pypi"
 
 [packages]
-discord-py = "~=1.3.1"
+discord-py = "~=1.3.2"
 aiodns = "~=2.0"
-logmatic-python = "~=0.1"
 aiohttp = "~=3.5"
 sphinx = "~=2.2"
 markdownify = "~=0.4"
@@ -17,30 +16,33 @@ aio-pika = "~=6.1"
 python-dateutil = "~=2.8"
 deepdiff = "~=4.0"
 requests = "~=2.22"
-more_itertools = "~=7.2"
-urllib3 = ">=1.24.2,<1.25"
+more_itertools = "~=8.2"
+sentry-sdk = "~=0.14"
+coloredlogs = "~=14.0"
+colorama = {version = "~=0.4.3", sys_platform = "== 'win32'"}
 
 [dev-packages]
-coverage = "~=4.5"
+coverage = "~=5.0"
 flake8 = "~=3.7"
-flake8-annotations = "~=1.1"
-flake8-bugbear = "~=19.8"
+flake8-annotations = "~=2.0"
+flake8-bugbear = "~=20.1"
 flake8-docstrings = "~=1.4"
 flake8-import-order = "~=0.18"
 flake8-string-format = "~=0.2"
-flake8-tidy-imports = "~=2.0"
+flake8-tidy-imports = "~=4.0"
 flake8-todo = "~=0.7"
-pre-commit = "~=1.18"
+pep8-naming = "~=0.9"
+pre-commit = "~=2.1"
 safety = "~=1.8"
-unittest-xml-reporting = "~=2.5"
+unittest-xml-reporting = "~=3.0"
 dodgy = "~=0.1"
 
 [requires]
-python_version = "3.7"
+python_version = "3.8"
 
 [scripts]
 start = "python -m bot"
-lint = "python -m flake8"
+lint = "pre-commit run --all-files"
 precommit = "pre-commit install"
 build = "docker build -t pythondiscord/bot:latest -f Dockerfile ."
push = "docker push pythondiscord/bot:latest" diff --git a/Pipfile.lock b/Pipfile.lock index bf8ff47e9..348456f2c 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,11 +1,11 @@ { "_meta": { "hash": { - "sha256": "0a0354a8cbd25b19c61b68f928493a445e737dc6447c97f4c4b52fbf72d887ac" + "sha256": "b8b38e84230bdc37f8c8955e8dddc442183a2e23c4dfc6ed37c522644aecdeea" }, "pipfile-spec": 6, "requires": { - "python_version": "3.7" + "python_version": "3.8" }, "sources": [ { @@ -18,11 +18,11 @@ "default": { "aio-pika": { "hashes": [ - "sha256:a5837277e53755078db3a9e8c45bbca605c8ba9ecba7a02d74a7a1779f444723", - "sha256:fa32e33b4b7d0804dcf439ae6ff24d2f0a83d1ba280ee9f555e647d71d394ff5" + "sha256:0332bc13abbd8923dac657b331716778c55ea0a32ac0951306ce85edafcc916c", + "sha256:39770d8bc7e9059e28622d599e2ac9ebc16a7198b33d1743c1a496ca3b0f8170" ], "index": "pypi", - "version": "==6.4.1" + "version": "==6.5.3" }, "aiodns": { "hashes": [ @@ -52,10 +52,10 @@ }, "aiormq": { "hashes": [ - "sha256:8c215a970133ab5ee7c478decac55b209af7731050f52d11439fe910fa0f9e9d", - "sha256:9210f3389200aee7d8067f6435f4a9eff2d3a30b88beb5eaae406ccc11c0fc01" + "sha256:286e0b0772075580466e45f98f051b9728a9316b9c36f0c14c7bc1409be375b0", + "sha256:7ed7d6df6b57af7f8bce7d1ebcbdfc32b676192e46703e81e9e217316e56b5bd" ], - "version": "==3.2.0" + "version": "==3.2.1" }, "alabaster": { "hashes": [ @@ -140,6 +140,23 @@ ], "version": "==3.0.4" }, + "colorama": { + "hashes": [ + "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff", + "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1" + ], + "index": "pypi", + "markers": "sys_platform == 'win32'", + "version": "==0.4.3" + }, + "coloredlogs": { + "hashes": [ + "sha256:346f58aad6afd48444c2468618623638dadab76e4e70d5e10822676f2d32226a", + "sha256:a1fab193d2053aa6c0a97608c4342d031f1f93a3d1218432c59322441d31a505" + ], + "index": "pypi", + "version": "==14.0" + }, "deepdiff": { "hashes": [ "sha256:b3fa588d1eac7fa318ec1fb4f2004568e04cb120a1989feda8e5e7164bcbf07a", @@ -150,10 +167,10 @@ }, "discord-py": { "hashes": [ - "sha256:8bfe5628d31771744000f19135c386c74ac337479d7282c26cc1627b9d31f360" + "sha256:7424be26b07b37ecad4404d9383d685995a0e0b3df3f9c645bdd3a4d977b83b4" ], "index": "pypi", - "version": "==1.3.1" + "version": "==1.3.2" }, "docutils": { "hashes": [ @@ -164,18 +181,25 @@ }, "fuzzywuzzy": { "hashes": [ - "sha256:5ac7c0b3f4658d2743aa17da53a55598144edbc5bee3c6863840636e6926f254", - "sha256:6f49de47db00e1c71d40ad16da42284ac357936fa9b66bea1df63fed07122d62" + "sha256:45016e92264780e58972dca1b3d939ac864b78437422beecebb3095f8efd00e8", + "sha256:928244b28db720d1e0ee7587acf660ea49d7e4c632569cad4f1cd7e68a5f0993" ], "index": "pypi", - "version": "==0.17.0" + "version": "==0.18.0" + }, + "humanfriendly": { + "hashes": [ + "sha256:2f79aaa2965c0fc3d79452e64ec2c7601d70d67e51ea2e99cb40afe3fe2824c5", + "sha256:6990c0af4b72f50ddf302900eb982edf199247e621e06d80d71b00b1a1574214" + ], + "version": "==8.0" }, "idna": { "hashes": [ - "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", - "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" + "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb", + "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa" ], - "version": "==2.8" + "version": "==2.9" }, "imagesize": { "hashes": [ @@ -191,13 +215,6 @@ ], "version": "==2.11.1" }, - "logmatic-python": { - "hashes": [ - "sha256:0c15ac9f5faa6a60059b28910db642c3dc7722948c3cc940923f8c9039604342" - ], - "index": "pypi", - 
"version": "==0.1.7" - }, "lxml": { "hashes": [ "sha256:06d4e0bbb1d62e38ae6118406d7cdb4693a3fa34ee3762238bcb96c9e36a93cd", @@ -278,33 +295,33 @@ }, "more-itertools": { "hashes": [ - "sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832", - "sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4" + "sha256:5dd8bcf33e5f9513ffa06d5ad33d78f31e1931ac9a18f33d37e77a180d393a7c", + "sha256:b1ddb932186d8a6ac451e1d95844b382f55e12686d51ca0c68b6f61f2ab7a507" ], "index": "pypi", - "version": "==7.2.0" + "version": "==8.2.0" }, "multidict": { "hashes": [ - "sha256:13f3ebdb5693944f52faa7b2065b751cb7e578b8dd0a5bb8e4ab05ad0188b85e", - "sha256:26502cefa86d79b86752e96639352c7247846515c864d7c2eb85d036752b643c", - "sha256:4fba5204d32d5c52439f88437d33ad14b5f228e25072a192453f658bddfe45a7", - "sha256:527124ef435f39a37b279653ad0238ff606b58328ca7989a6df372fd75d7fe26", - "sha256:5414f388ffd78c57e77bd253cf829373721f450613de53dc85a08e34d806e8eb", - "sha256:5eee66f882ab35674944dfa0d28b57fa51e160b4dce0ce19e47f495fdae70703", - "sha256:63810343ea07f5cd86ba66ab66706243a6f5af075eea50c01e39b4ad6bc3c57a", - "sha256:6bd10adf9f0d6a98ccc792ab6f83d18674775986ba9bacd376b643fe35633357", - "sha256:83c6ddf0add57c6b8a7de0bc7e2d656be3eefeff7c922af9a9aae7e49f225625", - "sha256:93166e0f5379cf6cd29746989f8a594fa7204dcae2e9335ddba39c870a287e1c", - "sha256:9a7b115ee0b9b92d10ebc246811d8f55d0c57e82dbb6a26b23c9a9a6ad40ce0c", - "sha256:a38baa3046cce174a07a59952c9f876ae8875ef3559709639c17fdf21f7b30dd", - "sha256:a6d219f49821f4b2c85c6d426346a5d84dab6daa6f85ca3da6c00ed05b54022d", - "sha256:a8ed33e8f9b67e3b592c56567135bb42e7e0e97417a4b6a771e60898dfd5182b", - "sha256:d7d428488c67b09b26928950a395e41cc72bb9c3d5abfe9f0521940ee4f796d4", - "sha256:dcfed56aa085b89d644af17442cdc2debaa73388feba4b8026446d168ca8dad7", - "sha256:f29b885e4903bd57a7789f09fe9d60b6475a6c1a4c0eca874d8558f00f9d4b51" - ], - "version": "==4.7.4" + "sha256:317f96bc0950d249e96d8d29ab556d01dd38888fbe68324f46fd834b430169f1", + "sha256:42f56542166040b4474c0c608ed051732033cd821126493cf25b6c276df7dd35", + "sha256:4b7df040fb5fe826d689204f9b544af469593fb3ff3a069a6ad3409f742f5928", + "sha256:544fae9261232a97102e27a926019100a9db75bec7b37feedd74b3aa82f29969", + "sha256:620b37c3fea181dab09267cd5a84b0f23fa043beb8bc50d8474dd9694de1fa6e", + "sha256:6e6fef114741c4d7ca46da8449038ec8b1e880bbe68674c01ceeb1ac8a648e78", + "sha256:7774e9f6c9af3f12f296131453f7b81dabb7ebdb948483362f5afcaac8a826f1", + "sha256:85cb26c38c96f76b7ff38b86c9d560dea10cf3459bb5f4caf72fc1bb932c7136", + "sha256:a326f4240123a2ac66bb163eeba99578e9d63a8654a59f4688a79198f9aa10f8", + "sha256:ae402f43604e3b2bc41e8ea8b8526c7fa7139ed76b0d64fc48e28125925275b2", + "sha256:aee283c49601fa4c13adc64c09c978838a7e812f85377ae130a24d7198c0331e", + "sha256:b51249fdd2923739cd3efc95a3d6c363b67bbf779208e9f37fd5e68540d1a4d4", + "sha256:bb519becc46275c594410c6c28a8a0adc66fe24fef154a9addea54c1adb006f5", + "sha256:c2c37185fb0af79d5c117b8d2764f4321eeb12ba8c141a95d0aa8c2c1d0a11dd", + "sha256:dc561313279f9d05a3d0ffa89cd15ae477528ea37aa9795c4654588a3287a9ab", + "sha256:e439c9a10a95cb32abd708bb8be83b2134fa93790a4fb0535ca36db3dda94d20", + "sha256:fc3b4adc2ee8474cb3cd2a155305d5f8eda0a9c91320f83e55748e1fcb68f8e3" + ], + "version": "==4.7.5" }, "ordered-set": { "hashes": [ @@ -380,6 +397,13 @@ ], "version": "==2.4.6" }, + "pyreadline": { + "hashes": [ + "sha256:4530592fc2e85b25b1a9f79664433da09237c1a270e4d78ea5aa3a2c7229e2d1" + ], + "markers": "sys_platform == 'win32'", + "version": "==2.1" + }, "python-dateutil": { "hashes": [ 
"sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", @@ -388,12 +412,6 @@ "index": "pypi", "version": "==2.8.1" }, - "python-json-logger": { - "hashes": [ - "sha256:b7a31162f2a01965a5efb94453ce69230ed208468b0bbc7fdfc56e6d8df2e281" - ], - "version": "==0.1.11" - }, "pytz": { "hashes": [ "sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d", @@ -420,11 +438,19 @@ }, "requests": { "hashes": [ - "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", - "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" + "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee", + "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6" ], "index": "pypi", - "version": "==2.22.0" + "version": "==2.23.0" + }, + "sentry-sdk": { + "hashes": [ + "sha256:480eee754e60bcae983787a9a13bc8f155a111aef199afaa4f289d6a76aa622a", + "sha256:a920387dc3ee252a66679d0afecd34479fb6fc52c2bc20763793ed69e5b0dcc0" + ], + "index": "pypi", + "version": "==0.14.2" }, "six": { "hashes": [ @@ -442,39 +468,39 @@ }, "soupsieve": { "hashes": [ - "sha256:bdb0d917b03a1369ce964056fc195cfdff8819c40de04695a80bc813c3cfa1f5", - "sha256:e2c1c5dee4a1c36bcb790e0fabd5492d874b8ebd4617622c4f6a731701060dda" + "sha256:e914534802d7ffd233242b785229d5ba0766a7f487385e3f714446a07bf540ae", + "sha256:fcd71e08c0aee99aca1b73f45478549ee7e7fc006d51b37bec9e9def7dc22b69" ], - "version": "==1.9.5" + "version": "==2.0" }, "sphinx": { "hashes": [ - "sha256:298537cb3234578b2d954ff18c5608468229e116a9757af3b831c2b2b4819159", - "sha256:e6e766b74f85f37a5f3e0773a1e1be8db3fcb799deb58ca6d18b70b0b44542a5" + "sha256:776ff8333181138fae52df65be733127539623bb46cc692e7fa0fcfc80d7aa88", + "sha256:ca762da97c3b5107cbf0ab9e11d3ec7ab8d3c31377266fd613b962ed971df709" ], "index": "pypi", - "version": "==2.3.1" + "version": "==2.4.3" }, "sphinxcontrib-applehelp": { "hashes": [ - "sha256:edaa0ab2b2bc74403149cb0209d6775c96de797dfd5b5e2a71981309efab3897", - "sha256:fb8dee85af95e5c30c91f10e7eb3c8967308518e0f7488a2828ef7bc191d0d5d" + "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a", + "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58" ], - "version": "==1.0.1" + "version": "==1.0.2" }, "sphinxcontrib-devhelp": { "hashes": [ - "sha256:6c64b077937330a9128a4da74586e8c2130262f014689b4b89e2d08ee7294a34", - "sha256:9512ecb00a2b0821a146736b39f7aeb90759834b07e81e8cc23a9c70bacb9981" + "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e", + "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4" ], - "version": "==1.0.1" + "version": "==1.0.2" }, "sphinxcontrib-htmlhelp": { "hashes": [ - "sha256:4670f99f8951bd78cd4ad2ab962f798f5618b17675c35c5ac3b2132a14ea8422", - "sha256:d4fd39a65a625c9df86d7fa8a2d9f3cd8299a3a4b15db63b50aac9e161d8eff7" + "sha256:3c0bc24a2c41e340ac37c85ced6dafc879ab485c095b1d65d2461ac2f7cca86f", + "sha256:e8f5bb7e31b2dbb25b9cc435c8ab7a79787ebf7f906155729338f3156d93659b" ], - "version": "==1.0.2" + "version": "==1.0.3" }, "sphinxcontrib-jsmath": { "hashes": [ @@ -485,25 +511,24 @@ }, "sphinxcontrib-qthelp": { "hashes": [ - "sha256:513049b93031beb1f57d4daea74068a4feb77aa5630f856fcff2e50de14e9a20", - "sha256:79465ce11ae5694ff165becda529a600c754f4bc459778778c7017374d4d406f" + "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72", + "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6" ], - "version": "==1.0.2" + "version": "==1.0.3" }, 
"sphinxcontrib-serializinghtml": { "hashes": [ - "sha256:c0efb33f8052c04fd7a26c0a07f1678e8512e0faec19f4aa8f2473a8b81d5227", - "sha256:db6615af393650bf1151a6cd39120c29abaf93cc60db8c48eb2dddbfdc3a9768" + "sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc", + "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a" ], - "version": "==1.1.3" + "version": "==1.1.4" }, "urllib3": { "hashes": [ - "sha256:2393a695cd12afedd0dcb26fe5d50d0cf248e5a66f75dbd89a3d4eb333a61af4", - "sha256:a637e5fae88995b256e3409dc4d52c2e2e0ba32c42a6365fee8bbd2238de3cfb" + "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", + "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" ], - "index": "pypi", - "version": "==1.24.3" + "version": "==1.25.8" }, "websockets": { "hashes": [ @@ -556,12 +581,12 @@ } }, "develop": { - "aspy.yaml": { + "appdirs": { "hashes": [ - "sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc", - "sha256:e7c742382eff2caed61f87a39d13f99109088e5e93f04d76eb8d4b28aa143f45" + "sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92", + "sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e" ], - "version": "==1.3.0" + "version": "==1.4.3" }, "attrs": { "hashes": [ @@ -579,10 +604,10 @@ }, "cfgv": { "hashes": [ - "sha256:edb387943b665bf9c434f717bf630fa78aecd53d5900d2e05da6ad6048553144", - "sha256:fbd93c9ab0a523bf7daec408f3be2ed99a980e20b2d19b50fc184ca6b820d289" + "sha256:1ccf53320421aeeb915275a196e23b3b8ae87dea8ac6698b1638001d4a486d53", + "sha256:c8e8f552ffcc6194f4e18dd4f68d9aef0c0d58ae7e7be8c82bee3c5e9edfa513" ], - "version": "==2.0.1" + "version": "==3.1.0" }, "chardet": { "hashes": [ @@ -600,41 +625,46 @@ }, "coverage": { "hashes": [ - "sha256:08907593569fe59baca0bf152c43f3863201efb6113ecb38ce7e97ce339805a6", - "sha256:0be0f1ed45fc0c185cfd4ecc19a1d6532d72f86a2bac9de7e24541febad72650", - "sha256:141f08ed3c4b1847015e2cd62ec06d35e67a3ac185c26f7635f4406b90afa9c5", - "sha256:19e4df788a0581238e9390c85a7a09af39c7b539b29f25c89209e6c3e371270d", - "sha256:23cc09ed395b03424d1ae30dcc292615c1372bfba7141eb85e11e50efaa6b351", - "sha256:245388cda02af78276b479f299bbf3783ef0a6a6273037d7c60dc73b8d8d7755", - "sha256:331cb5115673a20fb131dadd22f5bcaf7677ef758741312bee4937d71a14b2ef", - "sha256:386e2e4090f0bc5df274e720105c342263423e77ee8826002dcffe0c9533dbca", - "sha256:3a794ce50daee01c74a494919d5ebdc23d58873747fa0e288318728533a3e1ca", - "sha256:60851187677b24c6085248f0a0b9b98d49cba7ecc7ec60ba6b9d2e5574ac1ee9", - "sha256:63a9a5fc43b58735f65ed63d2cf43508f462dc49857da70b8980ad78d41d52fc", - "sha256:6b62544bb68106e3f00b21c8930e83e584fdca005d4fffd29bb39fb3ffa03cb5", - "sha256:6ba744056423ef8d450cf627289166da65903885272055fb4b5e113137cfa14f", - "sha256:7494b0b0274c5072bddbfd5b4a6c6f18fbbe1ab1d22a41e99cd2d00c8f96ecfe", - "sha256:826f32b9547c8091679ff292a82aca9c7b9650f9fda3e2ca6bf2ac905b7ce888", - "sha256:93715dffbcd0678057f947f496484e906bf9509f5c1c38fc9ba3922893cda5f5", - "sha256:9a334d6c83dfeadae576b4d633a71620d40d1c379129d587faa42ee3e2a85cce", - "sha256:af7ed8a8aa6957aac47b4268631fa1df984643f07ef00acd374e456364b373f5", - "sha256:bf0a7aed7f5521c7ca67febd57db473af4762b9622254291fbcbb8cd0ba5e33e", - "sha256:bf1ef9eb901113a9805287e090452c05547578eaab1b62e4ad456fcc049a9b7e", - "sha256:c0afd27bc0e307a1ffc04ca5ec010a290e49e3afbe841c5cafc5c5a80ecd81c9", - "sha256:dd579709a87092c6dbee09d1b7cfa81831040705ffa12a1b248935274aee0437", - "sha256:df6712284b2e44a065097846488f66840445eb987eb81b3cc6e4149e7b6982e1", 
- "sha256:e07d9f1a23e9e93ab5c62902833bf3e4b1f65502927379148b6622686223125c", - "sha256:e2ede7c1d45e65e209d6093b762e98e8318ddeff95317d07a27a2140b80cfd24", - "sha256:e4ef9c164eb55123c62411f5936b5c2e521b12356037b6e1c2617cef45523d47", - "sha256:eca2b7343524e7ba246cab8ff00cab47a2d6d54ada3b02772e908a45675722e2", - "sha256:eee64c616adeff7db37cc37da4180a3a5b6177f5c46b187894e633f088fb5b28", - "sha256:ef824cad1f980d27f26166f86856efe11eff9912c4fed97d3804820d43fa550c", - "sha256:efc89291bd5a08855829a3c522df16d856455297cf35ae827a37edac45f466a7", - "sha256:fa964bae817babece5aa2e8c1af841bebb6d0b9add8e637548809d040443fee0", - "sha256:ff37757e068ae606659c28c3bd0d923f9d29a85de79bf25b2b34b148473b5025" - ], - "index": "pypi", - "version": "==4.5.4" + "sha256:15cf13a6896048d6d947bf7d222f36e4809ab926894beb748fc9caa14605d9c3", + "sha256:1daa3eceed220f9fdb80d5ff950dd95112cd27f70d004c7918ca6dfc6c47054c", + "sha256:1e44a022500d944d42f94df76727ba3fc0a5c0b672c358b61067abb88caee7a0", + "sha256:25dbf1110d70bab68a74b4b9d74f30e99b177cde3388e07cc7272f2168bd1477", + "sha256:3230d1003eec018ad4a472d254991e34241e0bbd513e97a29727c7c2f637bd2a", + "sha256:3dbb72eaeea5763676a1a1efd9b427a048c97c39ed92e13336e726117d0b72bf", + "sha256:5012d3b8d5a500834783689a5d2292fe06ec75dc86ee1ccdad04b6f5bf231691", + "sha256:51bc7710b13a2ae0c726f69756cf7ffd4362f4ac36546e243136187cfcc8aa73", + "sha256:527b4f316e6bf7755082a783726da20671a0cc388b786a64417780b90565b987", + "sha256:722e4557c8039aad9592c6a4213db75da08c2cd9945320220634f637251c3894", + "sha256:76e2057e8ffba5472fd28a3a010431fd9e928885ff480cb278877c6e9943cc2e", + "sha256:77afca04240c40450c331fa796b3eab6f1e15c5ecf8bf2b8bee9706cd5452fef", + "sha256:7afad9835e7a651d3551eab18cbc0fdb888f0a6136169fbef0662d9cdc9987cf", + "sha256:9bea19ac2f08672636350f203db89382121c9c2ade85d945953ef3c8cf9d2a68", + "sha256:a8b8ac7876bc3598e43e2603f772d2353d9931709345ad6c1149009fd1bc81b8", + "sha256:b0840b45187699affd4c6588286d429cd79a99d509fe3de0f209594669bb0954", + "sha256:b26aaf69713e5674efbde4d728fb7124e429c9466aeaf5f4a7e9e699b12c9fe2", + "sha256:b63dd43f455ba878e5e9f80ba4f748c0a2156dde6e0e6e690310e24d6e8caf40", + "sha256:be18f4ae5a9e46edae3f329de2191747966a34a3d93046dbdf897319923923bc", + "sha256:c312e57847db2526bc92b9bfa78266bfbaabac3fdcd751df4d062cd4c23e46dc", + "sha256:c60097190fe9dc2b329a0eb03393e2e0829156a589bd732e70794c0dd804258e", + "sha256:c62a2143e1313944bf4a5ab34fd3b4be15367a02e9478b0ce800cb510e3bbb9d", + "sha256:cc1109f54a14d940b8512ee9f1c3975c181bbb200306c6d8b87d93376538782f", + "sha256:cd60f507c125ac0ad83f05803063bed27e50fa903b9c2cfee3f8a6867ca600fc", + "sha256:d513cc3db248e566e07a0da99c230aca3556d9b09ed02f420664e2da97eac301", + "sha256:d649dc0bcace6fcdb446ae02b98798a856593b19b637c1b9af8edadf2b150bea", + "sha256:d7008a6796095a79544f4da1ee49418901961c97ca9e9d44904205ff7d6aa8cb", + "sha256:da93027835164b8223e8e5af2cf902a4c80ed93cb0909417234f4a9df3bcd9af", + "sha256:e69215621707119c6baf99bda014a45b999d37602cb7043d943c76a59b05bf52", + "sha256:ea9525e0fef2de9208250d6c5aeeee0138921057cd67fcef90fbed49c4d62d37", + "sha256:fca1669d464f0c9831fd10be2eef6b86f5ebd76c724d1e0706ebdff86bb4adf0" + ], + "index": "pypi", + "version": "==5.0.3" + }, + "distlib": { + "hashes": [ + "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21" + ], + "version": "==0.3.0" }, "dodgy": { "hashes": [ @@ -658,6 +688,13 @@ ], "version": "==0.3" }, + "filelock": { + "hashes": [ + "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59", + 
"sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836" + ], + "version": "==3.0.12" + }, "flake8": { "hashes": [ "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", @@ -668,19 +705,19 @@ }, "flake8-annotations": { "hashes": [ - "sha256:05b85538014c850a86dce7374bb6621c64481c24e35e8e90af1315f4d7a3dbaa", - "sha256:43e5233a76fda002b91a54a7cc4510f099c4bfd6279502ec70164016250eebd1" + "sha256:a38b44d01abd480586a92a02a2b0a36231ec42dcc5e114de78fa5db016d8d3f9", + "sha256:d5b0e8704e4e7728b352fa1464e23539ff2341ba11cc153b536fa2cf921ee659" ], "index": "pypi", - "version": "==1.1.3" + "version": "==2.0.1" }, "flake8-bugbear": { "hashes": [ - "sha256:d8c466ea79d5020cb20bf9f11cf349026e09517a42264f313d3f6fddb83e0571", - "sha256:ded4d282778969b5ab5530ceba7aa1a9f1b86fa7618fc96a19a1d512331640f8" + "sha256:a3ddc03ec28ba2296fc6f89444d1c946a6b76460f859795b35b77d4920a51b63", + "sha256:bd02e4b009fb153fe6072c31c52aeab5b133d508095befb2ffcf3b41c4823162" ], "index": "pypi", - "version": "==19.8.0" + "version": "==20.1.4" }, "flake8-docstrings": { "hashes": [ @@ -698,21 +735,28 @@ "index": "pypi", "version": "==0.18.1" }, + "flake8-polyfill": { + "hashes": [ + "sha256:12be6a34ee3ab795b19ca73505e7b55826d5f6ad7230d31b18e106400169b9e9", + "sha256:e44b087597f6da52ec6393a709e7108b2905317d0c0b744cdca6208e670d8eda" + ], + "version": "==1.0.2" + }, "flake8-string-format": { "hashes": [ - "sha256:68ea72a1a5b75e7018cae44d14f32473c798cf73d75cbaed86c6a9a907b770b2", - "sha256:774d56103d9242ed968897455ef49b7d6de272000cfa83de5814273a868832f1" + "sha256:65f3da786a1461ef77fca3780b314edb2853c377f2e35069723348c8917deaa2", + "sha256:812ff431f10576a74c89be4e85b8e075a705be39bc40c4b4278b5b13e2afa9af" ], "index": "pypi", - "version": "==0.2.3" + "version": "==0.3.0" }, "flake8-tidy-imports": { "hashes": [ - "sha256:1c476aabc6e8db26dc75278464a3a392dba0ea80562777c5f13fd5cdf2646154", - "sha256:b3f5b96affd0f57cacb6621ed28286ce67edaca807757b51227043ebf7b136a1" + "sha256:8aa34384b45137d4cf33f5818b8e7897dc903b1d1e10a503fa7dd193a9a710ba", + "sha256:b26461561bcc80e8012e46846630ecf0aaa59314f362a94cb7800dfdb32fa413" ], "index": "pypi", - "version": "==2.0.0" + "version": "==4.0.0" }, "flake8-todo": { "hashes": [ @@ -730,18 +774,10 @@ }, "idna": { "hashes": [ - "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", - "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" + "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb", + "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa" ], - "version": "==2.8" - }, - "importlib-metadata": { - "hashes": [ - "sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302", - "sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b" - ], - "markers": "python_version < '3.8'", - "version": "==1.5.0" + "version": "==2.9" }, "mccabe": { "hashes": [ @@ -763,13 +799,21 @@ ], "version": "==20.1" }, + "pep8-naming": { + "hashes": [ + "sha256:45f330db8fcfb0fba57458c77385e288e7a3be1d01e8ea4268263ef677ceea5f", + "sha256:a33d38177056321a167decd6ba70b890856ba5025f0a8eca6a3eda607da93caf" + ], + "index": "pypi", + "version": "==0.9.1" + }, "pre-commit": { "hashes": [ - "sha256:8f48d8637bdae6fa70cc97db9c1dd5aa7c5c8bf71968932a380628c25978b850", - "sha256:f92a359477f3252452ae2e8d3029de77aec59415c16ae4189bcfba40b757e029" + "sha256:09ebe467f43ce24377f8c2f200fe3cd2570d328eb2ce0568c8e96ce19da45fa6", + "sha256:f8d555e31e2051892c7f7b3ad9f620bd2c09271d87e9eedb2ad831737d6211eb" ], "index": 
"pypi", - "version": "==1.21.0" + "version": "==2.1.1" }, "pycodestyle": { "hashes": [ @@ -818,11 +862,11 @@ }, "requests": { "hashes": [ - "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", - "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" + "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee", + "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6" ], "index": "pypi", - "version": "==2.22.0" + "version": "==2.23.0" }, "safety": { "hashes": [ @@ -853,62 +897,27 @@ ], "version": "==0.10.0" }, - "typed-ast": { - "hashes": [ - "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355", - "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919", - "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa", - "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652", - "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75", - "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01", - "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d", - "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1", - "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907", - "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c", - "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3", - "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b", - "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614", - "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb", - "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b", - "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41", - "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6", - "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34", - "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe", - "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4", - "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7" - ], - "markers": "python_version < '3.8'", - "version": "==1.4.1" - }, "unittest-xml-reporting": { "hashes": [ - "sha256:358bbdaf24a26d904cc1c26ef3078bca7fc81541e0a54c8961693cc96a6f35e0", - "sha256:9d28ddf6524cf0ff9293f61bd12e792de298f8561a5c945acea63fb437789e0e" + "sha256:74eaf7739a7957a74f52b8187c5616f61157372189bef0a32ba5c30bbc00e58a", + "sha256:e09b8ae70cce9904cdd331f53bf929150962869a5324ab7ff3dd6c8b87e01f7d" ], "index": "pypi", - "version": "==2.5.2" + "version": "==3.0.2" }, "urllib3": { "hashes": [ - "sha256:2393a695cd12afedd0dcb26fe5d50d0cf248e5a66f75dbd89a3d4eb333a61af4", - "sha256:a637e5fae88995b256e3409dc4d52c2e2e0ba32c42a6365fee8bbd2238de3cfb" + "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", + "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" ], - "index": "pypi", - "version": "==1.24.3" + "version": "==1.25.8" }, "virtualenv": { "hashes": [ - "sha256:0d62c70883c0342d59c11d0ddac0d954d0431321a41ab20851facf2b222598f3", - "sha256:55059a7a676e4e19498f1aad09b8313a38fcc0cdbe4fdddc0e9b06946d21b4bb" - ], - "version": "==16.7.9" - }, - "zipp": { - "hashes": [ - "sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28", - "sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98" + 
"sha256:30ea90b21dabd11da5f509710ad3be2ae47d40ccbc717dfdd2efe4367c10f598", + "sha256:4a36a96d785428278edd389d9c36d763c5755844beb7509279194647b1ef47f1" ], - "version": "==2.1.0" + "version": "==20.0.7" } } } diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0400ac4d2..16d1b7a2a 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -9,16 +9,18 @@ jobs: - job: test displayName: 'Lint & Test' pool: - vmImage: ubuntu-16.04 + vmImage: ubuntu-18.04 variables: PIP_CACHE_DIR: ".cache/pip" + PRE_COMMIT_HOME: $(Pipeline.Workspace)/pre-commit-cache steps: - task: UsePythonVersion@0 displayName: 'Set Python version' + name: PythonVersion inputs: - versionSpec: '3.7.x' + versionSpec: '3.8.x' addToPath: true - script: pip install pipenv @@ -27,10 +29,28 @@ jobs: - script: pipenv install --dev --deploy --system displayName: 'Install project using pipenv' - - script: python -m flake8 - displayName: 'Run linter' + # Create an executable shell script which replaces the original pipenv binary. + # The shell script ignores the first argument and executes the rest of the args as a command. + # It makes the `pipenv run flake8` command in the pre-commit hook work by circumventing + # pipenv entirely, which is too dumb to know it should use the system interpreter rather than + # creating a new venv. + - script: | + printf '%s\n%s' '#!/bin/bash' '"${@:2}"' > $(PythonVersion.pythonLocation)/bin/pipenv \ + && chmod +x $(PythonVersion.pythonLocation)/bin/pipenv + displayName: 'Mock pipenv binary' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner + - task: Cache@2 + displayName: 'Restore pre-commit environment' + inputs: + key: pre-commit | "$(PythonVersion.pythonLocation)" | .pre-commit-config.yaml + restoreKeys: | + pre-commit | "$(PythonVersion.pythonLocation)" + path: $(PRE_COMMIT_HOME) + + - script: pre-commit run --all-files + displayName: 'Run pre-commit hooks' + + - script: BOT_API_KEY=foo BOT_SENTRY_DSN=blah BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/__init__.py b/bot/__init__.py index 789ace5c0..c9dbc3f40 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,14 +1,13 @@ import logging import os import sys -from logging import Logger, StreamHandler, handlers +from logging import Logger, handlers from pathlib import Path -from logmatic import JsonFormatter +import coloredlogs - -logging.TRACE = 5 -logging.addLevelName(logging.TRACE, "TRACE") +TRACE_LEVEL = logging.TRACE = 5 +logging.addLevelName(TRACE_LEVEL, "TRACE") def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: @@ -20,75 +19,43 @@ def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) """ - if self.isEnabledFor(logging.TRACE): - self._log(logging.TRACE, msg, args, **kwargs) + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, msg, args, **kwargs) Logger.trace = monkeypatch_trace -# Set up logging -logging_handlers = [] - -# We can't import this yet, so we have to define it ourselves -DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False - -LOG_DIR = Path("logs") -LOG_DIR.mkdir(exist_ok=True) - -if DEBUG_MODE: - logging_handlers.append(StreamHandler(stream=sys.stdout)) - - json_handler = logging.FileHandler(filename=Path(LOG_DIR, "log.json"), 
mode="w") - json_handler.formatter = JsonFormatter() - logging_handlers.append(json_handler) -else: - - logfile = Path(LOG_DIR, "bot.log") - megabyte = 1048576 - - filehandler = handlers.RotatingFileHandler(logfile, maxBytes=(megabyte*5), backupCount=7) - logging_handlers.append(filehandler) - - json_handler = logging.StreamHandler(stream=sys.stdout) - json_handler.formatter = JsonFormatter() - logging_handlers.append(json_handler) - - -logging.basicConfig( - format="%(asctime)s Bot: | %(name)33s | %(levelname)8s | %(message)s", - datefmt="%b %d %H:%M:%S", - level=logging.TRACE if DEBUG_MODE else logging.INFO, - handlers=logging_handlers -) - -log = logging.getLogger(__name__) - +DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") -for key, value in logging.Logger.manager.loggerDict.items(): - # Force all existing loggers to the correct level and handlers - # This happens long before we instantiate our loggers, so - # those should still have the expected level +log_level = TRACE_LEVEL if DEBUG_MODE else logging.INFO +format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" +log_format = logging.Formatter(format_string) - if key == "bot": - continue +log_file = Path("logs", "bot.log") +log_file.parent.mkdir(exist_ok=True) +file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7) +file_handler.setFormatter(log_format) - if not isinstance(value, logging.Logger): - # There might be some logging.PlaceHolder objects in there - continue +root_log = logging.getLogger() +root_log.setLevel(log_level) +root_log.addHandler(file_handler) - if DEBUG_MODE: - value.setLevel(logging.DEBUG) - else: - value.setLevel(logging.INFO) +if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: + coloredlogs.DEFAULT_LEVEL_STYLES = { + **coloredlogs.DEFAULT_LEVEL_STYLES, + "trace": {"color": 246}, + "critical": {"background": "red"}, + "debug": coloredlogs.DEFAULT_LEVEL_STYLES["info"] + } - for handler in value.handlers.copy(): - value.removeHandler(handler) +if "COLOREDLOGS_LOG_FORMAT" not in os.environ: + coloredlogs.DEFAULT_LOG_FORMAT = format_string - for handler in logging_handlers: - value.addHandler(handler) +if "COLOREDLOGS_LOG_LEVEL" not in os.environ: + coloredlogs.DEFAULT_LOG_LEVEL = log_level +coloredlogs.install(logger=root_log, stream=sys.stdout) -# Silence irrelevant loggers -logging.getLogger("aio_pika").setLevel(logging.ERROR) -logging.getLogger("discord").setLevel(logging.ERROR) -logging.getLogger("websockets").setLevel(logging.ERROR) +logging.getLogger("discord").setLevel(logging.WARNING) +logging.getLogger("websockets").setLevel(logging.WARNING) +logging.getLogger(__name__) diff --git a/bot/__main__.py b/bot/__main__.py index 84bc7094b..3df477a6d 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,10 +1,23 @@ +import logging + import discord +import sentry_sdk from discord.ext.commands import when_mentioned_or +from sentry_sdk.integrations.logging import LoggingIntegration from bot import patches from bot.bot import Bot -from bot.constants import Bot as BotConfig, DEBUG_MODE +from bot.constants import Bot as BotConfig + +sentry_logging = LoggingIntegration( + level=logging.DEBUG, + event_level=logging.WARNING +) +sentry_sdk.init( + dsn=BotConfig.sentry_dsn, + integrations=[sentry_logging] +) bot = Bot( command_prefix=when_mentioned_or(BotConfig.prefix), @@ -18,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") bot.load_extension("bot.cogs.logging") bot.load_extension("bot.cogs.security") 
+bot.load_extension("bot.cogs.config_verifier")
 
 # Commands, etc
 bot.load_extension("bot.cogs.antimalware")
@@ -27,10 +41,8 @@ bot.load_extension("bot.cogs.clean")
 bot.load_extension("bot.cogs.extensions")
 bot.load_extension("bot.cogs.help")
 
-# Only load this in production
-if not DEBUG_MODE:
-    bot.load_extension("bot.cogs.doc")
-    bot.load_extension("bot.cogs.verification")
+bot.load_extension("bot.cogs.doc")
+bot.load_extension("bot.cogs.verification")
 
 # Feature cogs
 bot.load_extension("bot.cogs.alias")
diff --git a/bot/api.py b/bot/api.py
index 56db99828..4b8520582 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -32,6 +32,11 @@ class ResponseCodeError(ValueError):
 class APIClient:
     """Django Site API wrapper."""
 
+    # These are class attributes so they can be seen when being mocked for tests.
+    # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details.
+    session: Optional[aiohttp.ClientSession] = None
+    loop: asyncio.AbstractEventLoop = None
+
     def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs):
         auth_headers = {
             'Authorization': f"Token {Keys.site_api}"
@@ -42,12 +47,12 @@ class APIClient:
         else:
             kwargs['headers'] = auth_headers
 
-        self.session: Optional[aiohttp.ClientSession] = None
+        self.session = None
        self.loop = loop
 
         self._ready = asyncio.Event(loop=loop)
         self._creation_task = None
-        self._session_args = kwargs
+        self._default_session_kwargs = kwargs
 
         self.recreate()
 
@@ -55,25 +60,41 @@ class APIClient:
     def _url_for(endpoint: str) -> str:
         return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}"
 
-    async def _create_session(self) -> None:
-        """Create the aiohttp session and set the ready event."""
-        self.session = aiohttp.ClientSession(**self._session_args)
+    async def _create_session(self, **session_kwargs) -> None:
+        """
+        Create the aiohttp session with `session_kwargs` and set the ready event.
+
+        `session_kwargs` is merged with `_default_session_kwargs` and overwrites its values.
+        If an open session already exists, it will first be closed.
+        """
+        await self.close()
+        self.session = aiohttp.ClientSession(**{**self._default_session_kwargs, **session_kwargs})
         self._ready.set()
 
     async def close(self) -> None:
         """Close the aiohttp session and unset the ready event."""
-        if not self._ready.is_set():
-            return
+        if self.session:
+            await self.session.close()
 
-        await self.session.close()
         self._ready.clear()
 
-    def recreate(self) -> None:
-        """Schedule the aiohttp session to be created if it's been closed."""
-        if self.session is None or self.session.closed:
+    def recreate(self, force: bool = False, **session_kwargs) -> None:
+        """
+        Schedule the aiohttp session to be created with `session_kwargs` if it's been closed.
+
+        If `force` is True, the session will be recreated even if an open one exists. If a task to
+        create the session is pending, it will be cancelled.
+
+        `session_kwargs` is merged with the kwargs given when the `APIClient` was created and
+        overwrites those default kwargs.
+        """
+        if force or self.session is None or self.session.closed:
+            if force and self._creation_task:
+                self._creation_task.cancel()
+            # Don't schedule a task if one is already in progress.
- if self._creation_task is None or self._creation_task.done(): - self._creation_task = self.loop.create_task(self._create_session()) + if force or self._creation_task is None or self._creation_task.done(): + self._creation_task = self.loop.create_task(self._create_session(**session_kwargs)) async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: """Raise ResponseCodeError for non-OK response if an exception should be raised.""" @@ -85,43 +106,35 @@ class APIClient: response_text = await response.text() raise ResponseCodeError(response=response, response_text=response_text) - async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: - """Site API GET.""" + async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Send an HTTP request to the site API and return the JSON response.""" await self._ready.wait() - async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: + async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() - async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: - """Site API PATCH.""" - await self._ready.wait() + async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Site API GET.""" + return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) - async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() + async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Site API PATCH.""" + return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) - async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: + async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: """Site API POST.""" - await self._ready.wait() - - async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() + return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) - async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: + async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: """Site API PUT.""" - await self._ready.wait() + return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) - async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: + async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]: """Site API DELETE.""" await self._ready.wait() - async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: + async with self.session.delete(self._url_for(endpoint), **kwargs) as resp: if resp.status == 204: return None @@ -141,77 +154,3 @@ def loop_is_running() -> bool: except RuntimeError: return False return True - - -class APILoggingHandler(logging.StreamHandler): - """Site API logging 
handler.""" - - def __init__(self, client: APIClient): - logging.StreamHandler.__init__(self) - self.client = client - - # internal batch of shipoff tasks that must not be scheduled - # on the event loop yet - scheduled when the event loop is ready. - self.queue = [] - - async def ship_off(self, payload: dict) -> None: - """Ship log payload to the logging API.""" - try: - await self.client.post('logs', json=payload) - except ResponseCodeError as err: - log.warning( - "Cannot send logging record to the site, got code %d.", - err.response.status, - extra={'via_handler': True} - ) - except Exception as err: - log.warning( - "Cannot send logging record to the site: %r", - err, - extra={'via_handler': True} - ) - - def emit(self, record: logging.LogRecord) -> None: - """ - Determine if a log record should be shipped to the logging API. - - If the asyncio event loop is not yet running, log records will instead be put in a queue - which will be consumed once the event loop is running. - - The following two conditions are set: - 1. Do not log anything below DEBUG (only applies to the monkeypatched `TRACE` level) - 2. Ignore log records originating from this logging handler itself to prevent infinite recursion - """ - if ( - record.levelno >= logging.DEBUG - and not record.__dict__.get('via_handler') - ): - payload = { - 'application': 'bot', - 'logger_name': record.name, - 'level': record.levelname.lower(), - 'module': record.module, - 'line': record.lineno, - 'message': self.format(record) - } - - task = self.ship_off(payload) - if not loop_is_running(): - self.queue.append(task) - else: - asyncio.create_task(task) - self.schedule_queued_tasks() - - def schedule_queued_tasks(self) -> None: - """Consume the queue and schedule the logging of each queued record.""" - for task in self.queue: - asyncio.create_task(task) - - if self.queue: - log.debug( - "Scheduled %d pending logging tasks.", - len(self.queue), - extra={'via_handler': True} - ) - - self.queue.clear() diff --git a/bot/bot.py b/bot/bot.py index 8f808272f..950ac6751 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,11 +1,15 @@ +import asyncio import logging import socket +import warnings from typing import Optional import aiohttp +import discord from discord.ext import commands from bot import api +from bot import constants log = logging.getLogger('bot') @@ -14,20 +18,20 @@ class Bot(commands.Bot): """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" def __init__(self, *args, **kwargs): - # Use asyncio for DNS resolution instead of threads so threads aren't spammed. - # Use AF_INET as its socket family to prevent HTTPS related problems both locally - # and in production. 
- self.connector = aiohttp.TCPConnector( - resolver=aiohttp.AsyncResolver(), - family=socket.AF_INET, - ) + if "connector" in kwargs: + warnings.warn( + "If login() is called (or the bot is started), the connector will be overwritten " + "with an internal one" + ) - super().__init__(*args, connector=self.connector, **kwargs) + super().__init__(*args, **kwargs) self.http_session: Optional[aiohttp.ClientSession] = None - self.api_client = api.APIClient(loop=self.loop, connector=self.connector) + self.api_client = api.APIClient(loop=self.loop) - log.addHandler(api.APILoggingHandler(self.api_client)) + self._connector = None + self._resolver = None + self._guild_available = asyncio.Event() def add_cog(self, cog: commands.Cog) -> None: """Adds a "cog" to the bot and logs the operation.""" @@ -35,19 +39,105 @@ class Bot(commands.Bot): log.info(f"Cog loaded: {cog.qualified_name}") def clear(self) -> None: - """Clears the internal state of the bot and resets the API client.""" + """ + Clears the internal state of the bot and recreates the connector and sessions. + + Will cause a DeprecationWarning if called outside a coroutine. + """ + # Because discord.py recreates the HTTPClient session, may as well follow suit and recreate + # our own stuff here too. + self._recreate() super().clear() - self.api_client.recreate() async def close(self) -> None: - """Close the aiohttp session after closing the Discord connection.""" + """Close the Discord connection and the aiohttp session, connector, and resolver.""" await super().close() - await self.http_session.close() await self.api_client.close() - async def start(self, *args, **kwargs) -> None: - """Open an aiohttp session before logging in and connecting to Discord.""" - self.http_session = aiohttp.ClientSession(connector=self.connector) + if self.http_session: + await self.http_session.close() + + if self._connector: + await self._connector.close() + + if self._resolver: + await self._resolver.close() + + async def login(self, *args, **kwargs) -> None: + """Re-create the connector and set up sessions before logging into Discord.""" + self._recreate() + await super().login(*args, **kwargs) + + def _recreate(self) -> None: + """Re-create the connector, aiohttp session, and the APIClient.""" + # Use asyncio for DNS resolution instead of threads so threads aren't spammed. + # Doesn't seem to have any state with regards to being closed, so no need to worry? + self._resolver = aiohttp.AsyncResolver() + + # Its __del__ does send a warning but it doesn't always show up for some reason. + if self._connector and not self._connector._closed: + log.warning( + "The previous connector was not closed; it will remain open and be overwritten" + ) + + # Use AF_INET as its socket family to prevent HTTPS related problems both locally + # and in production. + self._connector = aiohttp.TCPConnector( + resolver=self._resolver, + family=socket.AF_INET, + ) + + # Client.login() will call HTTPClient.static_login() which will create a session using + # this connector attribute. + self.http.connector = self._connector + + # Its __del__ does send a warning but it doesn't always show up for some reason. 
+ if self.http_session and not self.http_session.closed: + log.warning( + "The previous session was not closed; it will remain open and be overwritten" + ) + + self.http_session = aiohttp.ClientSession(connector=self._connector) + self.api_client.recreate(force=True, connector=self._connector) + + async def on_guild_available(self, guild: discord.Guild) -> None: + """ + Set the internal guild available event when constants.Guild.id becomes available. + + If the cache appears to still be empty (no members, no channels, or no roles), the event + will not be set. + """ + if guild.id != constants.Guild.id: + return + + if not guild.roles or not guild.members or not guild.channels: + msg = "Guild available event was dispatched but the cache appears to still be empty!" + log.warning(msg) + + try: + webhook = await self.fetch_webhook(constants.Webhooks.dev_log) + except discord.HTTPException as e: + log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") + else: + await webhook.send(f"<@&{constants.Roles.admin}> {msg}") + + return + + self._guild_available.set() + + async def on_guild_unavailable(self, guild: discord.Guild) -> None: + """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" + if guild.id != constants.Guild.id: + return + + self._guild_available.clear() + + async def wait_until_guild_available(self) -> None: + """ + Wait until the constants.Guild.id guild is available (and the cache is ready). - await super().start(*args, **kwargs) + The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE + gateway event before giving up and thus not populating the cache for unavailable guilds. + """ + await self._guild_available.wait() diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 28e3e5d96..79bf486a4 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,10 +1,11 @@ import logging +from os.path import splitext from discord import Embed, Message, NotFound from discord.ext.commands import Cog from bot.bot import Bot -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs log = logging.getLogger(__name__) @@ -18,28 +19,40 @@ class AntiMalware(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" - if not message.attachments: + # Return when message don't have attachment and don't moderate DMs + if not message.attachments or not message.guild: + return + + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): return embed = Embed() - for attachment in message.attachments: - filename = attachment.filename.lower() - if filename.endswith('.py'): - embed.description = ( - f"It looks like you tried to attach a Python file - please " - f"use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" - ) - break # Other detections irrelevant because we prioritize the .py message. - if not filename.endswith(tuple(AntiMalwareConfig.whitelist)): - whitelisted_types = ', '.join(AntiMalwareConfig.whitelist) - meta_channel = self.bot.get_channel(Channels.meta) - embed.description = ( - f"It looks like you tried to attach a file type that we " - f"do not allow. 
We currently allow the following file " - f"types: **{whitelisted_types}**. \n\n Feel free to ask " - f"in {meta_channel.mention} if you think this is a mistake." - ) + file_extensions = {splitext(attachment.filename.lower())[1] for attachment in message.attachments} + extensions_blocked = file_extensions - set(AntiMalwareConfig.whitelist) + blocked_extensions_str = ', '.join(extensions_blocked) + if ".py" in extensions_blocked: + # Short-circuit on *.py files to provide a pastebin link + embed.description = ( + "It looks like you tried to attach a Python file - " + f"please use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" + ) + elif extensions_blocked: + whitelisted_types = ', '.join(AntiMalwareConfig.whitelist) + meta_channel = self.bot.get_channel(Channels.meta) + embed.description = ( + f"It looks like you tried to attach file type(s) that we do not allow ({blocked_extensions_str}). " + f"We currently allow the following file types: **{whitelisted_types}**.\n\n" + f"Feel free to ask in {meta_channel.mention} if you think this is a mistake." + ) + if embed.description: + log.info( + f"User '{message.author}' ({message.author.id}) uploaded blacklisted file(s): {blocked_extensions_str}", + extra={"attachment_list": [attachment.filename for attachment in message.attachments]} + ) + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) # Delete the offending message: diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index f67ef6f05..baa6b9459 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -123,7 +123,7 @@ class AntiSpam(Cog): async def alert_on_validation_error(self) -> None: """Unloads the cog and alerts admins if configuration validation failed.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if self.validation_errors: body = "**The following errors were encountered:**\n" body += "\n".join(f"- {error}" for error in self.validation_errors.values()) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 73b1e8f41..f17135877 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -34,13 +34,12 @@ class BotCog(Cog, name="Bot"): Channels.help_5: 0, Channels.help_6: 0, Channels.help_7: 0, - Channels.python: 0, + Channels.python_discussion: 0, } # These channels will also work, but will not be subject to cooldown self.channel_whitelist = ( - Channels.bot, - Channels.devtest, + Channels.bot_commands, ) # Stores improperly formatted Python codeblock message ids and the corresponding bot message diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index 2104efe57..5cdf0b048 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -173,7 +173,7 @@ class Clean(Cog): colour=Colour(Colours.soft_red), title="Bulk message delete", text=message, - channel_id=Channels.modlog, + channel_id=Channels.mod_log, ) @group(invoke_without_command=True, name="clean", aliases=["purge"]) diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/cogs/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. 
+ + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index 3e7350fcc..cc0f79fe8 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -59,7 +59,7 @@ class Defcon(Cog): async def sync_settings(self) -> None: """On cog load, try to synchronize DEFCON settings to the API.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() self.channel = await self.bot.fetch_channel(Channels.defcon) try: @@ -68,20 +68,20 @@ class Defcon(Cog): except Exception: # Yikes! log.exception("Unable to get DEFCON settings!") - await self.bot.get_channel(Channels.devlog).send( - f"<@&{Roles.admin}> **WARNING**: Unable to get DEFCON settings!" + await self.bot.get_channel(Channels.dev_log).send( + f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" ) else: if data["enabled"]: self.enabled = True self.days = timedelta(days=data["days"]) - log.warning(f"DEFCON enabled: {self.days.days} days") + log.info(f"DEFCON enabled: {self.days.days} days") else: self.enabled = False self.days = timedelta(days=0) - log.warning(f"DEFCON disabled") + log.info(f"DEFCON disabled") await self.update_channel_topic() @@ -118,7 +118,7 @@ class Defcon(Cog): ) @group(name='defcon', aliases=('dc',), invoke_without_command=True) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def defcon_group(self, ctx: Context) -> None: """Check the DEFCON status or run a subcommand.""" await ctx.invoke(self.bot.get_command("help"), "defcon") @@ -146,7 +146,7 @@ class Defcon(Cog): await self.send_defcon_log(action, ctx.author, error) @defcon_group.command(name='enable', aliases=('on', 'e')) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def enable_command(self, ctx: Context) -> None: """ Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! @@ -159,7 +159,7 @@ class Defcon(Cog): await self.update_channel_topic() @defcon_group.command(name='disable', aliases=('off', 'd')) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def disable_command(self, ctx: Context) -> None: """Disable DEFCON mode. 
Useful in a pinch, but be sure you know what you're doing!""" self.enabled = False @@ -167,7 +167,7 @@ class Defcon(Cog): await self.update_channel_topic() @defcon_group.command(name='status', aliases=('s',)) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def status_command(self, ctx: Context) -> None: """Check the current status of DEFCON mode.""" embed = Embed( @@ -179,7 +179,7 @@ class Defcon(Cog): await ctx.send(embed=embed) @defcon_group.command(name='days') - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def days_command(self, ctx: Context, days: int) -> None: """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" self.days = timedelta(days=days) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 6e7c00b6a..204cffb37 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -157,7 +157,7 @@ class Doc(commands.Cog): async def init_refresh_inventory(self) -> None: """Refresh documentation inventory on cog initialization.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() await self.refresh_inventory() async def update_single( diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 345d2856c..1f84a0609 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -22,7 +22,7 @@ class DuckPond(Cog): async def fetch_webhook(self) -> None: """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 52893b2ee..261769efc 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -1,24 +1,14 @@ import contextlib import logging +import typing as t -from discord.ext.commands import ( - BadArgument, - BotMissingPermissions, - CheckFailure, - CommandError, - CommandInvokeError, - CommandNotFound, - CommandOnCooldown, - DisabledCommand, - MissingPermissions, - NoPrivateMessage, - UserInputError, -) -from discord.ext.commands import Cog, Context +from discord.ext.commands import Cog, Command, Context, errors +from sentry_sdk import push_scope from bot.api import ResponseCodeError from bot.bot import Bot from bot.constants import Channels +from bot.converters import TagNameConverter from bot.decorators import InChannelCheckFailure log = logging.getLogger(__name__) @@ -31,126 +21,209 @@ class ErrorHandler(Cog): self.bot = bot @Cog.listener() - async def on_command_error(self, ctx: Context, e: CommandError) -> None: + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: """ Provide generic command error handling. - Error handling is deferred to any local error handler, if present. - - Error handling emits a single error response, prioritized as follows: - 1. If the name fails to match a command but matches a tag, the tag is invoked - 2. Send a BadArgument error message to the invoking context & invoke the command's help - 3. Send a UserInputError error message to the invoking context & invoke the command's help - 4. Send a NoPrivateMessage error message to the invoking context - 5. Send a BotMissingPermissions error message to the invoking context - 6. Log a MissingPermissions error, no message is sent - 7. Send a InChannelCheckFailure error message to the invoking context - 8. Log CheckFailure, CommandOnCooldown, and DisabledCommand errors, no message is sent - 9. 
For CommandInvokeErrors, response is based on the type of error: - * 404: Error message is sent to the invoking context - * 400: Log the resopnse JSON, no message is sent - * 500 <= status <= 600: Error message is sent to the invoking context - 10. Otherwise, handling is deferred to `handle_unexpected_error` + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command but matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` """ command = ctx.command - parent = None + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + # Try to look for a tag with the command's name if the command isn't found. + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if ctx.channel.id != Channels.verification: + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + async def get_help_command(self, command: t.Optional[Command]) -> t.Tuple: + """Return the help command invocation args to display help for `command`.""" + parent = None if command is not None: parent = command.parent # Retrieve the help command for the invoked command. if parent and command: - help_command = (self.bot.get_command("help"), parent.name, command.name) + return self.bot.get_command("help"), parent.name, command.name elif command: - help_command = (self.bot.get_command("help"), command.name) + return self.bot.get_command("help"), command.name else: - help_command = (self.bot.get_command("help"),) + return self.bot.get_command("help") - if hasattr(e, "handled"): - log.trace(f"Command {command} had its error already handled locally; ignoring.") + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. 
Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) return - # Try to look for a tag with the command's name if the command isn't found. - if isinstance(e, CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): - if not ctx.channel.id == Channels.verification: - tags_get_command = self.bot.get_command("tags get") - ctx.invoked_from_error_handler = True - - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." - try: - if not await tags_get_command.can_run(ctx): - log.debug(log_msg) - return - except CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - - # Return to not raise the exception - with contextlib.suppress(ResponseCodeError): - await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) - return - elif isinstance(e, BadArgument): + try: + tag_name = await TagNameConverter.convert(ctx, ctx.invoked_with) + except errors.BadArgument: + log.debug( + f"{ctx.author} tried to use an invalid command " + f"and the fallback tag failed validation in TagNameConverter." + ) + else: + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=tag_name) + # Return to not raise the exception + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + # TODO: use ctx.send_help() once PR #519 is merged. + help_command = await self.get_help_command(ctx.command) + + if isinstance(e, errors.MissingRequiredArgument): + await ctx.send(f"Missing required argument `{e.param.name}`.") + await ctx.invoke(*help_command) + elif isinstance(e, errors.TooManyArguments): + await ctx.send(f"Too many arguments provided.") + await ctx.invoke(*help_command) + elif isinstance(e, errors.BadArgument): await ctx.send(f"Bad argument: {e}\n") await ctx.invoke(*help_command) - elif isinstance(e, UserInputError): + elif isinstance(e, errors.BadUnionArgument): + await ctx.send(f"Bad argument: {e}\n```{e.errors[-1]}```") + elif isinstance(e, errors.ArgumentParsingError): + await ctx.send(f"Argument parsing error: {e}") + else: await ctx.send("Something about your input seems off. 
Check the arguments:") await ctx.invoke(*help_command) - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - elif isinstance(e, NoPrivateMessage): - await ctx.send("Sorry, this command can't be used in a private message!") - elif isinstance(e, BotMissingPermissions): - await ctx.send(f"Sorry, it looks like I don't have the permissions I need to do that.") - log.warning( - f"The bot is missing permissions to execute command {command}: {e.missing_perms}" - ) - elif isinstance(e, MissingPermissions): - log.debug( - f"{ctx.message.author} is missing permissions to invoke command {command}: " - f"{e.missing_perms}" + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InChannelCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + await ctx.send( + f"Sorry, it looks like I don't have the permissions or roles I need to do that." ) - elif isinstance(e, InChannelCheckFailure): + elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)): await ctx.send(e) - elif isinstance(e, (CheckFailure, CommandOnCooldown, DisabledCommand)): - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - elif isinstance(e, CommandInvokeError): - if isinstance(e.original, ResponseCodeError): - status = e.original.response.status - - if status == 404: - await ctx.send("There does not seem to be anything matching your query.") - elif status == 400: - content = await e.original.response.json() - log.debug(f"API responded with 400 for command {command}: %r.", content) - await ctx.send("According to the API, your request is malformed.") - elif 500 <= status < 600: - await ctx.send("Sorry, there seems to be an internal issue with the API.") - log.warning(f"API responded with {status} for command {command}") - else: - await ctx.send(f"Got an unexpected status code from the API (`{status}`).") - log.warning(f"Unexpected API response for command {command}: {status}") - else: - await self.handle_unexpected_error(ctx, e.original) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") else: - await self.handle_unexpected_error(ctx, e) + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") @staticmethod - async def handle_unexpected_error(ctx: Context, e: CommandError) -> None: - """Generic handler for errors without an explicit handler.""" + async def handle_unexpected_error(ctx: 
Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" await ctx.send( f"Sorry, an unexpected error occurred. Please let us know!\n\n" f"```{e.__class__.__name__}: {e}```" ) - log.error( - f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}" - ) - raise e + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) def setup(bot: Bot) -> None: diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 9c729f28a..52136fc8d 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -174,14 +174,14 @@ async def func(): # (None,) -> Any await ctx.send(f"```py\n{out}```", embed=embed) @group(name='internal', aliases=('int',)) - @with_role(Roles.owner, Roles.admin) + @with_role(Roles.owners, Roles.admins) async def internal_group(self, ctx: Context) -> None: """Internal commands. Top secret!""" if not ctx.invoked_subcommand: await ctx.invoke(self.bot.get_command("help"), "internal") @internal_group.command(name='eval', aliases=('e',)) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def eval(self, ctx: Context, *, code: str) -> None: """Run eval in a REPL-like format.""" code = code.strip("`") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index f16e79fb7..fb6cd9aa3 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -69,7 +69,7 @@ class Extensions(commands.Cog): @extensions_group.command(name="load", aliases=("l",)) async def load_command(self, ctx: Context, *extensions: Extension) -> None: - """ + r""" Load extensions given their fully qualified or unqualified names. If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded. @@ -86,7 +86,7 @@ class Extensions(commands.Cog): @extensions_group.command(name="unload", aliases=("ul",)) async def unload_command(self, ctx: Context, *extensions: Extension) -> None: - """ + r""" Unload currently loaded extensions given their fully qualified or unqualified names. If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded. @@ -109,7 +109,7 @@ class Extensions(commands.Cog): @extensions_group.command(name="reload", aliases=("r",)) async def reload_command(self, ctx: Context, *extensions: Extension) -> None: - """ + r""" Reload extensions given their fully qualified or unqualified names. If an extension fails to be reloaded, it will be rolled-back to the prior working state. @@ -221,7 +221,7 @@ class Extensions(commands.Cog): # This cannot be static (must have a __func__ attribute). def cog_check(self, ctx: Context) -> bool: """Only allow moderators and core developers to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer) + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) # This cannot be static (must have a __func__ attribute). 
async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 49cab6172..33b55e79a 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -22,7 +22,7 @@ class Free(Cog): PYTHON_HELP_ID = Categories.python_help @command(name="free", aliases=('f',)) - @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES) + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) async def free(self, ctx: Context, user: Member = None, seek: int = 2) -> None: """ Lists free help channels by likeliness of availability. @@ -55,7 +55,7 @@ class Free(Cog): msg = messages[seek - 1] # Otherwise get last message else: - msg = await channel.history(limit=1).next() # noqa (False positive) + msg = await channel.history(limit=1).next() # noqa: B305 inactive = (datetime.utcnow() - msg.created_at).seconds if inactive > TIMEOUT: diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 125d7ce24..7921a4932 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -2,19 +2,18 @@ import colorsys import logging import pprint import textwrap -import typing -from collections import defaultdict -from typing import Any, Mapping, Optional - -import discord -from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils -from discord.ext import commands -from discord.ext.commands import BucketType, Cog, Context, command, group +from collections import Counter, defaultdict +from string import Template +from typing import Any, Mapping, Optional, Union + +from discord import Colour, Embed, Member, Message, Role, Status, utils +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group from discord.utils import escape_markdown from bot import constants from bot.bot import Bot from bot.decorators import InChannelCheckFailure, in_channel, with_role +from bot.pagination import LinePaginator from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since @@ -32,34 +31,31 @@ class Information(Cog): async def roles_info(self, ctx: Context) -> None: """Returns a list of all roles and their corresponding IDs.""" # Sort the roles alphabetically and remove the @everyone role - roles = sorted(ctx.guild.roles, key=lambda role: role.name) - roles = [role for role in roles if role.name != "@everyone"] + roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) - # Build a string - role_string = "" + # Build a list + role_list = [] for role in roles: - role_string += f"`{role.id}` - {role.mention}\n" + role_list.append(f"`{role.id}` - {role.mention}") # Build an embed embed = Embed( - title="Role information", - colour=Colour.blurple(), - description=role_string + title=f"Role information (Total {len(roles)} role{'s' * (len(role_list) > 1)})", + colour=Colour.blurple() ) - embed.set_footer(text=f"Total roles: {len(roles)}") - - await ctx.send(embed=embed) + await LinePaginator.paginate(role_list, ctx, embed, empty=False) @with_role(*constants.MODERATION_ROLES) @command(name="role") - async def role_info(self, ctx: Context, *roles: typing.Union[Role, str]) -> None: + async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: """ Return information on a role or list of roles. To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. 
""" parsed_roles = [] + failed_roles = [] for role_name in roles: if isinstance(role_name, Role): @@ -70,29 +66,29 @@ class Information(Cog): role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) if not role: - await ctx.send(f":x: Could not convert `{role_name}` to a role") + failed_roles.append(role_name) continue parsed_roles.append(role) + if failed_roles: + await ctx.send( + ":x: I could not convert the following role names to a role: \n- " + "\n- ".join(failed_roles) + ) + for role in parsed_roles: + h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) + embed = Embed( title=f"{role.name} info", colour=role.colour, ) - embed.add_field(name="ID", value=role.id, inline=True) - embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) - - h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) - embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) - embed.add_field(name="Member count", value=len(role.members), inline=True) - embed.add_field(name="Position", value=role.position) - embed.add_field(name="Permission code", value=role.permissions.value, inline=True) await ctx.send(embed=embed) @@ -104,40 +100,23 @@ class Information(Cog): features = ", ".join(ctx.guild.features) region = ctx.guild.region - # How many of each type of channel? roles = len(ctx.guild.roles) - channels = ctx.guild.channels - text_channels = 0 - category_channels = 0 - voice_channels = 0 - for channel in channels: - if type(channel) == TextChannel: - text_channels += 1 - elif type(channel) == CategoryChannel: - category_channels += 1 - elif type(channel) == VoiceChannel: - voice_channels += 1 - - # How many of each user status? member_count = ctx.guild.member_count - members = ctx.guild.members - online = 0 - dnd = 0 - idle = 0 - offline = 0 - for member in members: - if str(member.status) == "online": - online += 1 - elif str(member.status) == "offline": - offline += 1 - elif str(member.status) == "idle": - idle += 1 - elif str(member.status) == "dnd": - dnd += 1 - embed = Embed( - colour=Colour.blurple(), - description=textwrap.dedent(f""" + # How many of each type of channel? + channels = Counter(c.type for c in ctx.guild.channels) + channel_counts = "".join(sorted(f"{str(ch).title()} channels: {channels[ch]}\n" for ch in channels)).strip() + + # How many of each user status? + statuses = Counter(member.status for member in ctx.guild.members) + embed = Embed(colour=Colour.blurple()) + + # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the + # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting + # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts + # after the dedent is made. 
+ embed.description = Template( + textwrap.dedent(f""" **Server information** Created: {created} Voice region: {region} @@ -146,18 +125,15 @@ class Information(Cog): **Counts** Members: {member_count:,} Roles: {roles} - Text: {text_channels} - Voice: {voice_channels} - Channel categories: {category_channels} + $channel_counts **Members** - {constants.Emojis.status_online} {online} - {constants.Emojis.status_idle} {idle} - {constants.Emojis.status_dnd} {dnd} - {constants.Emojis.status_offline} {offline} + {constants.Emojis.status_online} {statuses[Status.online]:,} + {constants.Emojis.status_idle} {statuses[Status.idle]:,} + {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} + {constants.Emojis.status_offline} {statuses[Status.offline]:,} """) - ) - + ).substitute({"channel_counts": channel_counts}) embed.set_thumbnail(url=ctx.guild.icon_url) await ctx.send(embed=embed) @@ -169,14 +145,14 @@ class Information(Cog): user = ctx.author # Do a role check if this is being executed on someone other than the caller - if user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): + elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): await ctx.send("You may not use this command on users other than yourself.") return # Non-staff may only do this in #bot-commands if not with_role_check(ctx, *constants.STAFF_ROLES): - if not ctx.channel.id == constants.Channels.bot: - raise InChannelCheckFailure(constants.Channels.bot) + if not ctx.channel.id == constants.Channels.bot_commands: + raise InChannelCheckFailure(constants.Channels.bot_commands) embed = await self.create_user_embed(ctx, user) @@ -202,7 +178,7 @@ class Information(Cog): name = f"{user.nick} ({name})" joined = time_since(user.joined_at, precision="days") - roles = ", ".join(role.mention for role in user.roles if role.name != "@everyone") + roles = ", ".join(role.mention for role in user.roles[1:]) description = [ textwrap.dedent(f""" @@ -355,14 +331,14 @@ class Information(Cog): @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) @group(invoke_without_command=True) - @in_channel(constants.Channels.bot, bypass_roles=constants.STAFF_ROLES) - async def raw(self, ctx: Context, *, message: discord.Message, json: bool = False) -> None: + @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES) + async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: """Shows information about the raw API response.""" # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling # doing this extra request is also much easier than trying to convert everything back into a dictionary again raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) - paginator = commands.Paginator() + paginator = Paginator() def add_content(title: str, content: str) -> None: paginator.add_line(f'== {title} ==\n') @@ -390,7 +366,7 @@ class Information(Cog): await ctx.send(page) @raw.command() - async def json(self, ctx: Context, message: discord.Message) -> None: + async def json(self, ctx: Context, message: Message) -> None: """Shows information about the raw API response in a copy-pasteable Python format.""" await ctx.invoke(self.raw, message=message, json=True) diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index 985f28ce5..1d062b0c2 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -18,7 +18,7 @@ class CodeJams(commands.Cog): self.bot = bot @commands.command() - 
@with_role(Roles.admin) + @with_role(Roles.admins) async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: """ Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. @@ -95,10 +95,10 @@ class CodeJams(commands.Cog): ) # Assign team leader role - await members[0].add_roles(ctx.guild.get_role(Roles.team_leader)) + await members[0].add_roles(ctx.guild.get_role(Roles.team_leaders)) # Assign rest of roles - jammer_role = ctx.guild.get_role(Roles.jammer) + jammer_role = ctx.guild.get_role(Roles.jammers) for member in members: await member.add_roles(jammer_role) diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index d1b7dcab3..94fa2b139 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -20,7 +20,7 @@ class Logging(Cog): async def startup_greeting(self) -> None: """Announce our presence to the configured devlog channel.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() log.info("Bot connected!") embed = Embed(description="Connected!") @@ -34,7 +34,7 @@ class Logging(Cog): ) if not DEBUG_MODE: - await self.bot.get_channel(Channels.devlog).send(embed=embed) + await self.bot.get_channel(Channels.dev_log).send(embed=embed) def setup(bot: Bot) -> None: diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index f4e296df9..9ea17b2b3 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -313,6 +313,6 @@ class Infractions(InfractionScheduler, commands.Cog): async def cog_command_error(self, ctx: Context, error: Exception) -> None: """Send a notification to the invoking context on a Union failure.""" if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: + if discord.User in error.converters or discord.Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 0636422d3..35448f682 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -1,4 +1,3 @@ -import asyncio import logging import textwrap import typing as t @@ -129,9 +128,13 @@ class ModManagement(commands.Cog): # Re-schedule infraction if the expiration has been updated if 'expires_at' in request_data: - self.infractions_cog.cancel_task(new_infraction['id']) - loop = asyncio.get_event_loop() - self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction) + # A scheduled task should only exist if the old infraction wasn't permanent + if old_infraction['expires_at']: + self.infractions_cog.cancel_task(new_infraction['id']) + + # If the infraction was not marked as permanent, schedule a new expiration task + if request_data['expires_at']: + self.infractions_cog.schedule_task(new_infraction['id'], new_infraction) log_text += f""" Previous expiry: {old_infraction['expires_at'] or "Permanent"} diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index e8ae0dbe6..81d95298d 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -67,7 +67,7 @@ class ModLog(Cog, name="ModLog"): 'embeds': [embed.to_dict() for embed in message.embeds], 'attachments': attachment, } - for message, attachment in zip_longest(messages, attachments) + for message, attachment in zip_longest(messages, attachments, fillvalue=[]) ] } ) @@ -87,7 +87,7 @@ class ModLog(Cog, name="ModLog"): title: t.Optional[str], text: 
str, thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, - channel_id: int = Channels.modlog, + channel_id: int = Channels.mod_log, ping_everyone: bool = False, files: t.Optional[t.List[discord.File]] = None, content: t.Optional[str] = None, @@ -377,7 +377,7 @@ class ModLog(Cog, name="ModLog"): Icons.user_ban, Colours.soft_red, "User banned", f"{member} (`{member.id}`)", thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"): Icons.sign_in, Colours.soft_green, "User joined", message, thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -416,7 +416,7 @@ class ModLog(Cog, name="ModLog"): Icons.sign_out, Colours.soft_red, "User left", f"{member} (`{member.id}`)", thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -433,7 +433,7 @@ class ModLog(Cog, name="ModLog"): Icons.user_unban, Colour.blurple(), "User unbanned", f"{member} (`{member.id}`)", thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.modlog + channel_id=Channels.mod_log ) @Cog.listener() @@ -529,7 +529,7 @@ class ModLog(Cog, name="ModLog"): Icons.user_update, Colour.blurple(), "Member updated", message, thumbnail=after.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -538,7 +538,7 @@ class ModLog(Cog, name="ModLog"): channel = message.channel author = message.author - if message.guild.id != GuildConstant.id or channel.id in GuildConstant.ignored: + if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist: return self._cached_deletes.append(message.id) @@ -591,7 +591,7 @@ class ModLog(Cog, name="ModLog"): @Cog.listener() async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: """Log raw message delete event to message change log.""" - if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.ignored: + if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist: return await asyncio.sleep(1) # Wait here in case the normal event was fired @@ -635,7 +635,7 @@ class ModLog(Cog, name="ModLog"): if ( not msg_before.guild or msg_before.guild.id != GuildConstant.id - or msg_before.channel.id in GuildConstant.ignored + or msg_before.channel.id in GuildConstant.modlog_blacklist or msg_before.author.bot ): return @@ -717,7 +717,7 @@ class ModLog(Cog, name="ModLog"): if ( not message.guild or message.guild.id != GuildConstant.id - or message.channel.id in GuildConstant.ignored + or message.channel.id in GuildConstant.modlog_blacklist or message.author.bot ): return @@ -769,7 +769,7 @@ class ModLog(Cog, name="ModLog"): """Log member voice state changes to the voice log channel.""" if ( member.guild.id != GuildConstant.id - or (before.channel and before.channel.id in GuildConstant.ignored) + or (before.channel and before.channel.id in GuildConstant.modlog_blacklist) ): return diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index e14c302cb..f0b6b2c48 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -1,3 +1,4 @@ +import asyncio import logging import textwrap import typing as t @@ -38,7 +39,7 @@ class InfractionScheduler(Scheduler): async def reschedule_infractions(self, 
supported_infractions: t.Container[str]) -> None: """Schedule expiration for previous infractions.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") @@ -48,7 +49,7 @@ class InfractionScheduler(Scheduler): ) for infraction in infractions: if infraction["expires_at"] is not None and infraction["type"] in supported_infractions: - self.schedule_task(self.bot.loop, infraction["id"], infraction) + self.schedule_task(infraction["id"], infraction) async def reapply_infraction( self, @@ -150,7 +151,7 @@ class InfractionScheduler(Scheduler): await action_coro if expiry: # Schedule the expiration of the infraction. - self.schedule_task(ctx.bot.loop, infraction["id"], infraction) + self.schedule_task(infraction["id"], infraction) except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = f":x: failed to apply" @@ -307,18 +308,25 @@ class InfractionScheduler(Scheduler): Infractions of unsupported types will raise a ValueError. """ guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderator) + mod_role = guild.get_role(constants.Roles.moderators) user_id = infraction["user"] + actor = infraction["actor"] type_ = infraction["type"] id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] log.info(f"Marking infraction #{id_} as inactive (expired).") + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + log_content = None log_text = { - "Member": str(user_id), - "Actor": str(self.bot.user), - "Reason": infraction["reason"] + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, } try: @@ -384,14 +392,19 @@ class InfractionScheduler(Scheduler): if send_log: log_title = f"expiration failed" if "Failure" in log_text else "expired" + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + log.trace(f"Sending deactivation mod log for infraction #{id_}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS[type_][1], colour=Colours.soft_green, title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, text="\n".join(f"{k}: {v}" for k, v in log_text.items()), footer=f"ID: {id_}", content=log_content, + ) return log_text @@ -415,4 +428,6 @@ class InfractionScheduler(Scheduler): expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None) await time.wait_until(expiry) - await self.deactivate_infraction(infraction) + # Because deactivate_infraction() explicitly cancels this scheduled task, it is shielded + # to avoid prematurely cancelling itself. + await asyncio.shield(self.deactivate_infraction(infraction)) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 050c847ac..893cb7f13 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -109,7 +109,8 @@ class Superstarify(InfractionScheduler, Cog): ctx: Context, member: Member, duration: Expiry, - reason: str = None + *, + reason: str = None, ) -> None: """ Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. 
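The `asyncio.shield()` call added to `bot/cogs/moderation/scheduler.py` above is subtle: the expiry coroutine runs inside the very task that `deactivate_infraction()` cancels, so without the shield it would cancel itself part-way through its own cleanup. Below is a minimal, self-contained sketch of that pattern; the `tasks` dict and the `deactivate`/`expire` names are illustrative stand-ins, not the bot's real scheduler API.

import asyncio

tasks = {}  # task_id -> asyncio.Task, standing in for the scheduler's internal mapping


async def deactivate(task_id: int) -> None:
    """Cancel the scheduled task for `task_id`, then finish some follow-up work."""
    tasks.pop(task_id).cancel()  # cancels the expiry task; unshielded, that is the task running this coroutine
    await asyncio.sleep(0)       # a suspension point, like the bot's API and mod-log calls
    print(f"cleanup for task {task_id} completed")


async def expire(task_id: int, shielded: bool) -> None:
    """Wait out the expiry, then deactivate; loosely mirrors the expiry task above."""
    if shielded:
        # shield() runs deactivate() in its own task, so the self-inflicted
        # cancellation only hits this wrapper, not the cleanup itself.
        await asyncio.shield(deactivate(task_id))
    else:
        await deactivate(task_id)  # cleanup is cut short by its own cancel() call


async def main() -> None:
    for shielded in (False, True):
        tasks[1] = asyncio.create_task(expire(1, shielded))
        try:
            await tasks[1]
        except asyncio.CancelledError:
            pass


asyncio.run(main())  # only the shielded pass prints its completion message

This is why the rewritten scheduler can let `deactivate_infraction()` own task cancellation without the expiry task tearing itself down before its own cleanup finishes.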
@@ -145,7 +146,7 @@ class Superstarify(InfractionScheduler, Cog): log.debug(f"Changing nickname of {member} to {forced_nick}.") self.mod_log.ignore(constants.Event.member_update, member.id) await member.edit(nick=forced_nick, reason=reason) - self.schedule_task(ctx.bot.loop, id_, infraction) + self.schedule_task(id_, infraction) # Send a DM to the user to notify them of their new infraction. await utils.notify_infraction( diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index bf777ea5a..81511f99d 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -88,7 +88,7 @@ class OffTopicNames(Cog): async def init_offtopic_updater(self) -> None: """Start off-topic channel updating event loop if it hasn't already started.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if self.updater_task is None: coro = update_names(self.bot) self.updater_task = self.bot.loop.create_task(coro) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index aa487f18e..5a7fa100f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -43,12 +43,12 @@ class Reddit(Cog): def cog_unload(self) -> None: """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() - if self.access_token.expires_at < datetime.utcnow(): - self.revoke_access_token() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if not self.webhook: self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) @@ -83,7 +83,7 @@ class Reddit(Cog): expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) - log.debug(f"New token acquired; expires on {self.access_token.expires_at}") + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") return else: log.debug( @@ -208,7 +208,7 @@ class Reddit(Cog): await asyncio.sleep(seconds_until) - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) @@ -290,4 +290,7 @@ class Reddit(Cog): def setup(bot: Bot) -> None: """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 45bf9a8f4..24c279357 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -2,16 +2,17 @@ import asyncio import logging import random import textwrap +import typing as t from datetime import datetime, timedelta from operator import itemgetter -from typing import Optional +import discord +from dateutil.parser import isoparse from dateutil.relativedelta import relativedelta -from discord import Colour, Embed, Message from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES +from bot.constants import Guild, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES from bot.converters import Duration from bot.pagination import LinePaginator from bot.utils.checks import without_role_check @@ -20,7 +21,7 @@ from bot.utils.time import humanize_delta, wait_until log = logging.getLogger(__name__) -WHITELISTED_CHANNELS = (Channels.bot,) +WHITELISTED_CHANNELS = 
Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 @@ -35,39 +36,73 @@ class Reminders(Scheduler, Cog): async def reschedule_reminders(self) -> None: """Get all current reminders from the API and reschedule them.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() response = await self.bot.api_client.get( 'bot/reminders', params={'active': 'true'} ) now = datetime.utcnow() - loop = asyncio.get_event_loop() for reminder in response: - remind_at = datetime.fromisoformat(reminder['expiration'][:-1]) + is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + if not is_valid: + continue + + remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) # If the reminder is already overdue ... if remind_at < now: late = relativedelta(now, remind_at) await self.send_reminder(reminder, late) - else: - self.schedule_task(loop, reminder["id"], reminder) + self.schedule_task(reminder["id"], reminder) + + def ensure_valid_reminder( + self, + reminder: dict, + cancel_task: bool = True + ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" + user = self.bot.get_user(reminder['author']) + channel = self.bot.get_channel(reminder['channel_id']) + is_valid = True + if not user or not channel: + is_valid = False + log.info( + f"Reminder {reminder['id']} invalid: " + f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." + ) + asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + + return is_valid, user, channel @staticmethod - async def _send_confirmation(ctx: Context, on_success: str) -> None: + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: str, + delivery_dt: t.Optional[datetime], + ) -> None: """Send an embed confirming the reminder change was made successfully.""" - embed = Embed() - embed.colour = Colour.green() + embed = discord.Embed() + embed.colour = discord.Colour.green() embed.title = random.choice(POSITIVE_REPLIES) embed.description = on_success + + footer_str = f"ID: {reminder_id}" + if delivery_dt: + # Reminder deletion will have a `None` `delivery_dt` + footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + + embed.set_footer(text=footer_str) + await ctx.send(embed=embed) async def _scheduled_task(self, reminder: dict) -> None: """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" reminder_id = reminder["id"] - reminder_datetime = datetime.fromisoformat(reminder['expiration'][:-1]) + reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) # Send the reminder message once the desired duration has passed await wait_until(reminder_datetime) @@ -76,30 +111,30 @@ class Reminders(Scheduler, Cog): log.debug(f"Deleting reminder {reminder_id} (the user has been reminded).") await self._delete_reminder(reminder_id) - # Now we can begone with it from our schedule list. 
- self.cancel_task(reminder_id) - - async def _delete_reminder(self, reminder_id: str) -> None: + async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: """Delete a reminder from the database, given its ID, and cancel the running task.""" await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - # Now we can remove it from the schedule list - self.cancel_task(reminder_id) + if cancel_task: + # Now we can remove it from the schedule list + self.cancel_task(reminder_id) async def _reschedule_reminder(self, reminder: dict) -> None: """Reschedule a reminder object.""" - loop = asyncio.get_event_loop() - + log.trace(f"Cancelling old task #{reminder['id']}") self.cancel_task(reminder["id"]) - self.schedule_task(loop, reminder["id"], reminder) + + log.trace(f"Scheduling new task #{reminder['id']}") + self.schedule_task(reminder["id"], reminder) async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: """Send the reminder.""" - channel = self.bot.get_channel(reminder["channel_id"]) - user = self.bot.get_user(reminder["author"]) + is_valid, user, channel = self.ensure_valid_reminder(reminder) + if not is_valid: + return - embed = Embed() - embed.colour = Colour.blurple() + embed = discord.Embed() + embed.colour = discord.Colour.blurple() embed.set_author( icon_url=Icons.remind_blurple, name="It has arrived!" @@ -111,7 +146,7 @@ class Reminders(Scheduler, Cog): embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" if late: - embed.colour = Colour.red() + embed.colour = discord.Colour.red() embed.set_author( icon_url=Icons.remind_red, name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" @@ -129,20 +164,20 @@ class Reminders(Scheduler, Cog): await ctx.invoke(self.new_reminder, expiration=expiration, content=content) @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> Optional[Message]: + async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> t.Optional[discord.Message]: """ Set yourself a simple reminder. Expiration is parsed per: http://strftime.org/ """ - embed = Embed() + embed = discord.Embed() # If the user is not staff, we need to verify whether or not to make a reminder at all. if without_role_check(ctx, *STAFF_ROLES): # If they don't have permission to set a reminder in this channel if ctx.channel.id not in WHITELISTED_CHANNELS: - embed.colour = Colour.red() + embed.colour = discord.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = "Sorry, you can't do that here!" @@ -159,7 +194,7 @@ class Reminders(Scheduler, Cog): # Let's limit this, so we don't get 10 000 # reminders from kip or something like that :P if len(active_reminders) > MAXIMUM_REMINDERS: - embed.colour = Colour.red() + embed.colour = discord.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = "You have too many active reminders!" @@ -178,18 +213,20 @@ class Reminders(Scheduler, Cog): ) now = datetime.utcnow() - timedelta(seconds=1) + humanized_delta = humanize_delta(relativedelta(expiration, now)) # Confirm to the user that it worked. await self._send_confirmation( ctx, - on_success=f"Your reminder will arrive in {humanize_delta(relativedelta(expiration, now))}!" 
+ on_success=f"Your reminder will arrive in {humanized_delta}!", + reminder_id=reminder["id"], + delivery_dt=expiration, ) - loop = asyncio.get_event_loop() - self.schedule_task(loop, reminder["id"], reminder) + self.schedule_task(reminder["id"], reminder) @remind_group.command(name="list") - async def list_reminders(self, ctx: Context) -> Optional[Message]: + async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]: """View a paginated embed of all reminders for your user.""" # Get all the user's reminders from the database. data = await self.bot.api_client.get( @@ -212,7 +249,7 @@ class Reminders(Scheduler, Cog): for content, remind_at, id_ in reminders: # Parse and humanize the time, make it pretty :D - remind_datetime = datetime.fromisoformat(remind_at[:-1]) + remind_datetime = isoparse(remind_at).replace(tzinfo=None) time = humanize_delta(relativedelta(remind_datetime, now)) text = textwrap.dedent(f""" @@ -222,8 +259,8 @@ class Reminders(Scheduler, Cog): lines.append(text) - embed = Embed() - embed.colour = Colour.blurple() + embed = discord.Embed() + embed.colour = discord.Colour.blurple() embed.title = f"Reminders for {ctx.author}" # Remind the user that they have no reminders :^) @@ -232,7 +269,7 @@ class Reminders(Scheduler, Cog): return await ctx.send(embed=embed) # Construct the embed and paginate it. - embed.colour = Colour.blurple() + embed.colour = discord.Colour.blurple() await LinePaginator.paginate( lines, @@ -261,7 +298,10 @@ class Reminders(Scheduler, Cog): # Send a confirmation message to the channel await self._send_confirmation( - ctx, on_success="That reminder has been edited successfully!" + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, ) await self._reschedule_reminder(reminder) @@ -275,18 +315,27 @@ class Reminders(Scheduler, Cog): json={'content': content} ) + # Parse the reminder expiration back into a datetime for the confirmation message + expiration = isoparse(reminder['expiration']).replace(tzinfo=None) + # Send a confirmation message to the channel await self._send_confirmation( - ctx, on_success="That reminder has been edited successfully!" + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, ) await self._reschedule_reminder(reminder) - @remind_group.command("delete", aliases=("remove",)) + @remind_group.command("delete", aliases=("remove", "cancel")) async def delete_reminder(self, ctx: Context, id_: int) -> None: """Delete one of your active reminders.""" await self._delete_reminder(id_) await self._send_confirmation( - ctx, on_success="That reminder has been deleted successfully!" 
+ ctx, + on_success="That reminder has been deleted successfully!", + reminder_id=id_, + delivery_dt=None, ) diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index da33e27b2..cff7c5786 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -1,10 +1,14 @@ +import asyncio +import contextlib import datetime import logging import re import textwrap +from functools import partial from signal import Signals from typing import Optional, Tuple +from discord import HTTPException, Message, NotFound, Reaction, User from discord.ext.commands import Cog, Context, command, guild_only from bot.bot import Bot @@ -34,7 +38,11 @@ RAW_CODE_REGEX = re.compile( ) MAX_PASTE_LEN = 1000 -EVAL_ROLES = (Roles.helpers, Roles.moderator, Roles.admin, Roles.owner, Roles.rockstars, Roles.partners) +EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) + +SIGKILL = 9 + +REEVAL_EMOJI = '\U0001f501' # :repeat: class Snekbox(Cog): @@ -101,7 +109,7 @@ class Snekbox(Cog): if returncode is None: msg = "Your eval job has failed" error = stdout.strip() - elif returncode == 128 + Signals.SIGKILL: + elif returncode == 128 + SIGKILL: msg = "Your eval job timed out or ran out of memory" elif returncode == 255: msg = "Your eval job has failed" @@ -135,7 +143,7 @@ class Snekbox(Cog): """ log.trace("Formatting output...") - output = output.strip(" \n") + output = output.rstrip("\n") original_output = output # To be uploaded to a pasting service if needed paste_link = None @@ -152,8 +160,8 @@ class Snekbox(Cog): lines = output.count("\n") if lines > 0: - output = output.split("\n")[:10] # Only first 10 cause the rest is truncated anyway - output = (f"{i:03d} | {line}" for i, line in enumerate(output, 1)) + output = [f"{i:03d} | {line}" for i, line in enumerate(output.split('\n'), 1)] + output = output[:11] # Limiting to only 11 lines output = "\n".join(output) if lines > 10: @@ -169,21 +177,84 @@ class Snekbox(Cog): if truncated: paste_link = await self.upload_output(original_output) - output = output.strip() - if not output: - output = "[No output]" + output = output or "[No output]" return output, paste_link + async def send_eval(self, ctx: Context, code: str) -> Message: + """ + Evaluate code, format it, and send the output to the corresponding channel. + + Return the bot response. + """ + async with ctx.typing(): + results = await self.post_eval(code) + msg, error = self.get_results_message(results) + + if error: + output, paste_link = error, None + else: + output, paste_link = await self.format_output(results["stdout"]) + + icon = self.get_status_emoji(results) + msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```" + if paste_link: + msg = f"{msg}\nFull output: {paste_link}" + + response = await ctx.send(msg) + self.bot.loop.create_task( + wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) + ) + + log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") + return response + + async def continue_eval(self, ctx: Context, response: Message) -> Optional[str]: + """ + Check if the eval session should continue. + + Return the new code to evaluate or None if the eval session should be terminated. 
+ """ + _predicate_eval_message_edit = partial(predicate_eval_message_edit, ctx) + _predicate_emoji_reaction = partial(predicate_eval_emoji_reaction, ctx) + + with contextlib.suppress(NotFound): + try: + _, new_message = await self.bot.wait_for( + 'message_edit', + check=_predicate_eval_message_edit, + timeout=10 + ) + await ctx.message.add_reaction(REEVAL_EMOJI) + await self.bot.wait_for( + 'reaction_add', + check=_predicate_emoji_reaction, + timeout=10 + ) + + code = new_message.content.split(' ', maxsplit=1)[1] + await ctx.message.clear_reactions() + with contextlib.suppress(HTTPException): + await response.delete() + + except asyncio.TimeoutError: + await ctx.message.clear_reactions() + return None + + return code + @command(name="eval", aliases=("e",)) @guild_only() - @in_channel(Channels.bot, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) + @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. This command supports multiple lines of code, including code wrapped inside a formatted code - block. We've done our best to make this safe, but do let us know if you manage to find an + block. Code can be re-evaluated by editing the original message within 10 seconds and + clicking the reaction that subsequently appears. + + We've done our best to make this sandboxed, but do let us know if you manage to find an issue with it! """ if ctx.author.id in self.jobs: @@ -199,32 +270,28 @@ class Snekbox(Cog): log.info(f"Received code from {ctx.author} for evaluation:\n{code}") - self.jobs[ctx.author.id] = datetime.datetime.now() - code = self.prepare_input(code) + while True: + self.jobs[ctx.author.id] = datetime.datetime.now() + code = self.prepare_input(code) + try: + response = await self.send_eval(ctx, code) + finally: + del self.jobs[ctx.author.id] + + code = await self.continue_eval(ctx, response) + if not code: + break + log.info(f"Re-evaluating message {ctx.message.id}") + + +def predicate_eval_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> bool: + """Return True if the edited message is the context message and the content was indeed modified.""" + return new_msg.id == ctx.message.id and old_msg.content != new_msg.content - try: - async with ctx.typing(): - results = await self.post_eval(code) - msg, error = self.get_results_message(results) - - if error: - output, paste_link = error, None - else: - output, paste_link = await self.format_output(results["stdout"]) - - icon = self.get_status_emoji(results) - msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```" - if paste_link: - msg = f"{msg}\nFull output: {paste_link}" - - response = await ctx.send(msg) - self.bot.loop.create_task( - wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot) - ) - log.info(f"{ctx.author}'s job had a return code of {results['returncode']}") - finally: - del self.jobs[ctx.author.id] +def predicate_eval_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: + """Return True if the reaction REEVAL_EMOJI was added by the context message author on this message.""" + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REEVAL_EMOJI def setup(bot: Bot) -> None: diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 4e6ed156b..5708be3f4 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -1,7 +1,7 @@ import logging -from typing 
import Callable, Dict, Iterable, Union +from typing import Any, Dict -from discord import Guild, Member, Role, User +from discord import Member, Role, User from discord.ext import commands from discord.ext.commands import Cog, Context @@ -16,45 +16,28 @@ log = logging.getLogger(__name__) class Sync(Cog): """Captures relevant events and sends them to the site.""" - # The server to synchronize events on. - # Note that setting this wrongly will result in things getting deleted - # that possibly shouldn't be. - SYNC_SERVER_ID = constants.Guild.id - - # An iterable of callables that are called when the bot is ready. - ON_READY_SYNCERS: Iterable[Callable[[Bot, Guild], None]] = ( - syncers.sync_roles, - syncers.sync_users - ) - def __init__(self, bot: Bot) -> None: self.bot = bot + self.role_syncer = syncers.RoleSyncer(self.bot) + self.user_syncer = syncers.UserSyncer(self.bot) self.bot.loop.create_task(self.sync_guild()) async def sync_guild(self) -> None: """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_ready() - guild = self.bot.get_guild(self.SYNC_SERVER_ID) - if guild is not None: - for syncer in self.ON_READY_SYNCERS: - syncer_name = syncer.__name__[5:] # drop off `sync_` - log.info("Starting `%s` syncer.", syncer_name) - total_created, total_updated, total_deleted = await syncer(self.bot, guild) - if total_deleted is None: - log.info( - f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`." - ) - else: - log.info( - f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`, " - f"deleted `{total_deleted}`." - ) - - async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None: + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None: """Send a PATCH request to partially update a user in the database.""" try: - await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information) + await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information) except ResponseCodeError as e: if e.response.status != 404: raise @@ -82,12 +65,14 @@ class Sync(Cog): @Cog.listener() async def on_guild_role_update(self, before: Role, after: Role) -> None: """Syncs role with the database if any of the stored attributes were updated.""" - if ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ): + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: await self.bot.api_client.put( f'bot/roles/{after.id}', json={ @@ -137,18 +122,8 @@ class Sync(Cog): @Cog.listener() async def on_member_remove(self, member: Member) -> None: - """Updates the user information when a member leaves the guild.""" - await self.bot.api_client.put( - f'bot/users/{member.id}', - json={ - 'avatar_hash': member.avatar, - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': False, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - ) + """Set the in_guild field to False when a member leaves the guild.""" + await self.patch_user(member.id, 
updated_information={"in_guild": False}) @Cog.listener() async def on_member_update(self, before: Member, after: Member) -> None: @@ -160,7 +135,8 @@ class Sync(Cog): @Cog.listener() async def on_user_update(self, before: User, after: User) -> None: """Update the user information in the database if a relevant change is detected.""" - if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")): + attrs = ("name", "discriminator", "avatar") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): updated_information = { "name": after.name, "discriminator": int(after.discriminator), @@ -176,25 +152,11 @@ class Sync(Cog): @sync_group.command(name='roles') @commands.has_permissions(administrator=True) async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronize the guild's roles with the roles on the site.""" - initial_response = await ctx.send("📊 Synchronizing roles.") - total_created, total_updated, total_deleted = await syncers.sync_roles(self.bot, ctx.guild) - await initial_response.edit( - content=( - f"👌 Role synchronization complete, created **{total_created}** " - f", updated **{total_created}** roles, and deleted **{total_deleted}** roles." - ) - ) + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) @sync_group.command(name='users') @commands.has_permissions(administrator=True) async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronize the guild's users with the users on the site.""" - initial_response = await ctx.send("📊 Synchronizing users.") - total_created, total_updated, total_deleted = await syncers.sync_users(self.bot, ctx.guild) - await initial_response.edit( - content=( - f"👌 User synchronization complete, created **{total_created}** " - f"and updated **{total_created}** users." - ) - ) + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 14cf51383..c7ce54d65 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -1,235 +1,342 @@ +import abc +import logging +import typing as t from collections import namedtuple -from typing import Dict, Set, Tuple +from functools import partial -from discord import Guild +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context +from bot import constants +from bot.api import ResponseCodeError from bot.bot import Bot +log = logging.getLogger(__name__) + # These objects are declared as namedtuples because tuples are hashable, # something that we make use of when diffing site roles against guild roles. -Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild')) - - -def get_roles_for_sync( - guild_roles: Set[Role], api_roles: Set[Role] -) -> Tuple[Set[Role], Set[Role], Set[Role]]: - """ - Determine which roles should be created or updated on the site. - - Arguments: - guild_roles (Set[Role]): - Roles that were found on the guild at startup. - - api_roles (Set[Role]): - Roles that were retrieved from the API at startup. - - Returns: - Tuple[Set[Role], Set[Role]. Set[Role]]: - A tuple with three elements. The first element represents - roles to be created on the site, meaning that they were - present on the cached guild but not on the API. 
The second - element represents roles to be updated, meaning they were - present on both the cached guild and the API but non-ID - fields have changed inbetween. The third represents roles - to be deleted on the site, meaning the roles are present on - the API but not in the cached guild. - """ - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in api_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # API guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - api_roles - roles_to_create - roles_to_delete = {role for role in api_roles if role.id in deleted_role_ids} - return roles_to_create, roles_to_update, roles_to_delete - - -async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]: - """ - Synchronize roles found on the given `guild` with the ones on the API. - - Arguments: - bot (bot.bot.Bot): - The bot instance that we're running with. - - guild (discord.Guild): - The guild instance from the bot's cache - to synchronize roles with. - - Returns: - Tuple[int, int, int]: - A tuple with three integers representing how many roles were created - (element `0`) , how many roles were updated (element `1`), and how many - roles were deleted (element `2`) on the API. - """ - roles = await bot.api_client.get('bot/roles') - - # Pack API roles and guild roles into one common format, - # which is also hashable. We need hashability to be able - # to compare these easily later using sets. - api_roles = {Role(**role_dict) for role_dict in roles} - guild_roles = { - Role( - id=role.id, name=role.name, - colour=role.colour.value, permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - roles_to_create, roles_to_update, roles_to_delete = get_roles_for_sync(guild_roles, api_roles) - - for role in roles_to_create: - await bot.api_client.post( - 'bot/roles', - json={ - 'id': role.id, - 'name': role.name, - 'colour': role.colour, - 'permissions': role.permissions, - 'position': role.position, - } - ) +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - for role in roles_to_update: - await bot.api_client.put( - f'bot/roles/{role.id}', - json={ - 'id': role.id, - 'name': role.name, - 'colour': role.colour, - 'permissions': role.permissions, - 'position': role.position, - } - ) - for role in roles_to_delete: - await bot.api_client.delete(f'bot/roles/{role.id}') - - return len(roles_to_create), len(roles_to_update), len(roles_to_delete) - - -def get_users_for_sync( - guild_users: Dict[int, User], api_users: Dict[int, User] -) -> Tuple[Set[User], Set[User]]: - """ - Determine which users should be created or updated on the website. - - Arguments: - guild_users (Dict[int, User]): - A mapping of user IDs to user data, populated from the - guild cached on the running bot instance. - - api_users (Dict[int, User]): - A mapping of user IDs to user data, populated from the API's - current inventory of all users. - - Returns: - Tuple[Set[User], Set[User]]: - Two user sets as a tuple. 
The first element represents users - to be created on the website, these are users that are present - in the cached guild data but not in the API at all, going by - their ID. The second element represents users to update. It is - populated by users which are present on both the API and the - guild, but where the attribute of a user on the API is not - equal to the attribute of the user on the guild. - """ - users_to_create = set() - users_to_update = set() - - for api_user in api_users.values(): - guild_user = guild_users.get(api_user.id) - if guild_user is not None: - if api_user != guild_user: - users_to_update.add(guild_user) - - elif api_user.in_guild: - # The user is known on the API but not the guild, and the - # API currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = api_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(api_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return users_to_create, users_to_update - - -async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]: - """ - Synchronize users found in the given `guild` with the ones in the API. - - Arguments: - bot (bot.bot.Bot): - The bot instance that we're running with. - - guild (discord.Guild): - The guild instance from the bot's cache - to synchronize roles with. - - Returns: - Tuple[int, int, None]: - A tuple with two integers, representing how many users were created - (element `0`) and how many users were updated (element `1`), and `None` - to indicate that a user sync never deletes entries from the API. - """ - current_users = await bot.api_client.get('bot/users') - - # Pack API users and guild users into one common format, - # which is also hashable. We need hashability to be able - # to compare these easily later using sets. 
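The set-based user diffing above (and its replacement further down in `UserSyncer`) relies on namedtuples being hashable and comparing field-by-field. Below is a minimal, self-contained sketch of that comparison idea, using made-up records rather than real guild members; it is illustrative only, not the bot's actual helper.

```python
from collections import namedtuple

# Hashable, field-wise comparable stand-ins for the cached guild and API records.
User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild'))

# Hypothetical data: user 1 changed their name, user 2 left, user 3 just joined.
api_users = {
    1: User(1, 'lemon', 1234, 'a1', (10,), True),
    2: User(2, 'volcyy', 5678, 'b2', (10,), True),
}
guild_users = {
    1: User(1, 'lemonsaurus', 1234, 'a1', (10,), True),
    3: User(3, 'scragly', 9012, 'c3', (10,), True),
}

to_update = set()
for api_user in api_users.values():
    guild_user = guild_users.get(api_user.id)
    if guild_user is not None and api_user != guild_user:
        to_update.add(guild_user)  # A field changed; tuples compare field-wise.
    elif guild_user is None and api_user.in_guild:
        to_update.add(api_user._replace(in_guild=False))  # User left since the last sync.

to_create = {guild_users[uid] for uid in guild_users.keys() - api_users.keys()}

print(sorted(u.id for u in to_create))   # [3]
print(sorted(u.id for u in to_update))   # [1, 2]
```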
- api_users = { - user_dict['id']: User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in current_users - } - guild_users = { - member.id: User( - id=member.id, name=member.name, - discriminator=int(member.discriminator), avatar_hash=member.avatar, - roles=tuple(sorted(role.id for role in member.roles)), in_guild=True - ) - for member in guild.members - } - - users_to_create, users_to_update = get_users_for_sync(guild_users, api_users) - - for user in users_to_create: - await bot.api_client.post( - 'bot/users', - json={ - 'avatar_hash': user.avatar_hash, - 'discriminator': user.discriminator, - 'id': user.id, - 'in_guild': user.in_guild, - 'name': user.name, - 'roles': list(user.roles) - } +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' ) - for user in users_to_update: - await bot.api_client.put( - f'bot/users/{user.id}', - json={ - 'avatar_hash': user.avatar_hash, - 'discriminator': user.discriminator, - 'id': user.id, - 'in_guild': user.in_guild, - 'name': user.name, - 'roles': list(user.roles) - } + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + message = await channel.send(f"{self._CORE_DEV_MENTION}{msg_content}") + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. 
+ """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS ) - return len(users_to_create), len(users_to_update), None + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.warning(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. 
+ + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. 
+ roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + avatar_hash=member.avatar, + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. 
+ new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 54a51921c..c6b442912 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -1,7 +1,7 @@ import logging import re import time -from typing import Dict, List, Optional +from typing import Callable, Dict, Iterable, List, Optional from discord import Colour, Embed from discord.ext.commands import Cog, Context, group @@ -15,8 +15,7 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) TEST_CHANNELS = ( - Channels.devtest, - Channels.bot, + Channels.bot_commands, Channels.helpers ) @@ -87,11 +86,80 @@ class Tags(Cog): return self._get_suggestions(tag_name) return found + async def _get_tags_via_content(self, check: Callable[[Iterable], bool], keywords: str) -> list: + """ + Search for tags via contents. + + `predicate` will be the built-in any, all, or a custom callable. Must return a bool. + """ + await self._get_tags() + + keywords_processed: List[str] = [] + for keyword in keywords.split(','): + keyword_sanitized = keyword.strip().casefold() + if not keyword_sanitized: + # this happens when there are leading / trailing / consecutive comma. + continue + keywords_processed.append(keyword_sanitized) + + if not keywords_processed: + # after sanitizing, we can end up with an empty list, for example when keywords is ',' + # in that case, we simply want to search for such keywords directly instead. + keywords_processed = [keywords] + + matching_tags = [] + for tag in self._cache.values(): + if check(query in tag['embed']['description'].casefold() for query in keywords_processed): + matching_tags.append(tag) + + return matching_tags + + async def _send_matching_tags(self, ctx: Context, keywords: str, matching_tags: list) -> None: + """Send the result of matching tags to user.""" + if not matching_tags: + pass + elif len(matching_tags) == 1: + await ctx.send(embed=Embed().from_dict(matching_tags[0]['embed'])) + else: + is_plural = keywords.strip().count(' ') > 0 or keywords.strip().count(',') > 0 + embed = Embed( + title=f"Here are the tags containing the given keyword{'s' * is_plural}:", + description='\n'.join(tag['title'] for tag in matching_tags[:10]) + ) + await LinePaginator.paginate( + sorted(f"**»** {tag['title']}" for tag in matching_tags), + ctx, + embed, + footer_text="To show a tag, type !tags <tagname>.", + empty=False, + max_lines=15 + ) + @group(name='tags', aliases=('tag', 't'), invoke_without_command=True) async def tags_group(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: """Show all known tags, a single tag, or run a subcommand.""" await ctx.invoke(self.get_command, tag_name=tag_name) + @tags_group.group(name='search', invoke_without_command=True) + async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Only search for tags that has ALL the keywords. 
+ """ + matching_tags = await self._get_tags_via_content(all, keywords) + await self._send_matching_tags(ctx, keywords, matching_tags) + + @search_tag_content.command(name='any') + async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = None) -> None: + """ + Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. + + Search for tags that has ANY of the keywords. + """ + matching_tags = await self._get_tags_via_content(any, keywords or 'any') + await self._send_matching_tags(ctx, keywords, matching_tags) + @tags_group.command(name='get', aliases=('show', 'g')) async def get_command(self, ctx: Context, *, tag_name: TagNameConverter = None) -> None: """Get a specified tag, or a list of all tags if no tag is specified.""" @@ -116,8 +184,10 @@ class Tags(Cog): if _command_on_cooldown(tag_name): time_left = Cooldowns.tags - (time.time() - self.tag_cooldowns[tag_name]["time"]) - log.warning(f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds.") + log.info( + f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " + f"Cooldown ends in {time_left:.1f} seconds." + ) return await self._get_tags() @@ -219,7 +289,7 @@ class Tags(Cog): )) @tags_group.command(name='delete', aliases=('remove', 'rm', 'd')) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def delete_command(self, ctx: Context, *, tag_name: TagNameConverter) -> None: """Remove a tag from the database.""" await self.bot.api_client.delete(f'bot/tags/{tag_name}') diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 82c01ae96..547ba8da0 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -96,12 +96,19 @@ class TokenRemover(Cog): if msg.author.bot: return False - maybe_match = TOKEN_RE.search(msg.content) - if maybe_match is None: + # Use findall rather than search to guard against method calls prematurely returning the + # token check (e.g. `message.channel.send` also matches our token pattern) + maybe_matches = TOKEN_RE.findall(msg.content) + if not maybe_matches: return False + return any(cls.is_maybe_token(substr) for substr in maybe_matches) + + @classmethod + def is_maybe_token(cls, test_str: str) -> bool: + """Check the provided string to see if it is a seemingly valid token.""" try: - user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') + user_id, creation_timestamp, hmac = test_str.split('.') except ValueError: return False diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index da278011a..024141d62 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -1,14 +1,15 @@ +import difflib import logging import re import unicodedata from asyncio import TimeoutError, sleep from email.parser import HeaderParser from io import StringIO -from typing import Tuple +from typing import Tuple, Union from dateutil import relativedelta from discord import Colour, Embed, Message, Role -from discord.ext.commands import Cog, Context, command +from discord.ext.commands import BadArgument, Cog, Context, command from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES @@ -17,6 +18,28 @@ from bot.utils.time import humanize_delta log = logging.getLogger(__name__) +ZEN_OF_PYTHON = """\ +Beautiful is better than ugly. +Explicit is better than implicit. +Simple is better than complex. +Complex is better than complicated. +Flat is better than nested. 
+Sparse is better than dense. +Readability counts. +Special cases aren't special enough to break the rules. +Although practicality beats purity. +Errors should never pass silently. +Unless explicitly silenced. +In the face of ambiguity, refuse the temptation to guess. +There should be one-- and preferably only one --obvious way to do it. +Although that way may not be obvious at first unless you're Dutch. +Now is better than never. +Although never is often better than *right* now. +If the implementation is hard to explain, it's a bad idea. +If the implementation is easy to explain, it may be a good idea. +Namespaces are one honking great idea -- let's do more of those! +""" + class Utils(Cog): """A selection of utilities which don't have a clear category.""" @@ -89,7 +112,7 @@ class Utils(Cog): await ctx.message.channel.send(embed=pep_embed) @command() - @in_channel(Channels.bot, bypass_roles=STAFF_ROLES) + @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES) async def charinfo(self, ctx: Context, *, characters: str) -> None: """Shows you information on up to 25 unicode characters.""" match = re.match(r"<(a?):(\w+):(\d+)>", characters) @@ -173,6 +196,88 @@ class Utils(Cog): f"as I detected unauthorised use by {msg.author} (ID: {msg.author.id})." ) + @command() + async def zen(self, ctx: Context, *, search_value: Union[int, str, None] = None) -> None: + """ + Show the Zen of Python. + + Without any arguments, the full Zen will be produced. + If an integer is provided, the line with that index will be produced. + If a string is provided, the line which matches best will be produced. + """ + embed = Embed( + colour=Colour.blurple(), + title="The Zen of Python", + description=ZEN_OF_PYTHON + ) + + if search_value is None: + embed.title += ", by Tim Peters" + await ctx.send(embed=embed) + return + + zen_lines = ZEN_OF_PYTHON.splitlines() + + # handle if it's an index int + if isinstance(search_value, int): + upper_bound = len(zen_lines) - 1 + lower_bound = -1 * upper_bound + if not (lower_bound <= search_value <= upper_bound): + raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") + + embed.title += f" (line {search_value % len(zen_lines)}):" + embed.description = zen_lines[search_value] + await ctx.send(embed=embed) + return + + # handle if it's a search string + matcher = difflib.SequenceMatcher(None, search_value.lower()) + + best_match = "" + match_index = 0 + best_ratio = 0 + + for index, line in enumerate(zen_lines): + matcher.set_seq2(line.lower()) + + # the match ratio needs to be adjusted because, naturally, + # longer lines will have worse ratios than shorter lines when + # fuzzy searching for keywords. this seems to work okay. + adjusted_ratio = (len(line) - 5) ** 0.5 * matcher.ratio() + + if adjusted_ratio > best_ratio: + best_ratio = adjusted_ratio + best_match = line + match_index = index + + if not best_match: + raise BadArgument("I didn't get a match! Please try again with a different search term.") + + embed.title += f" (line {match_index}):" + embed.description = best_match + await ctx.send(embed=embed) + + @command(aliases=("poll",)) + @with_role(*MODERATION_ROLES) + async def vote(self, ctx: Context, title: str, *options: str) -> None: + """ + Build a quick voting poll with matching reactions with the provided options. + + A maximum of 20 options can be provided, as Discord supports a max of 20 + reactions on a single message. 
+ """ + if len(options) < 2: + raise BadArgument("Please provide at least 2 options.") + if len(options) > 20: + raise BadArgument("I can only handle 20 options!") + + codepoint_start = 127462 # represents "regional_indicator_a" unicode value + options = {chr(i): f"{chr(i)} - {v}" for i, v in enumerate(options, start=codepoint_start)} + embed = Embed(title=title, description="\n".join(options.values())) + message = await ctx.send(embed=embed) + for reaction in options: + await message.add_reaction(reaction) + def setup(bot: Bot) -> None: """Load the Utils cog.""" diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 988e0d49a..57b50c34f 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -1,7 +1,8 @@ import logging +from contextlib import suppress from datetime import datetime -from discord import Colour, Message, NotFound, Object +from discord import Colour, Forbidden, Message, NotFound, Object from discord.ext import tasks from discord.ext.commands import Cog, Context, command @@ -29,15 +30,16 @@ your information removed here as well. Feel free to review them at any point! Additionally, if you'd like to receive notifications for the announcements we post in <#{Channels.announcements}> \ -from time to time, you can send `!subscribe` to <#{Channels.bot}> at any time to assign yourself the \ +from time to time, you can send `!subscribe` to <#{Channels.bot_commands}> at any time to assign yourself the \ **Announcements** role. We'll mention this role every time we make an announcement. -If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to <#{Channels.bot}>. +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{Channels.bot_commands}>. """ PERIODIC_PING = ( f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`." - f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel." + f" If you encounter any problems during the verification process, ping the <@&{Roles.admins}> role in this channel." ) BOT_MESSAGE_DELETE_DELAY = 10 @@ -92,19 +94,21 @@ class Verification(Cog): ping_everyone=Filter.ping_everyone, ) - ctx = await self.bot.get_context(message) # type: Context - + ctx: Context = await self.bot.get_context(message) if ctx.command is not None and ctx.command.name == "accept": - return # They used the accept command + return - for role in ctx.author.roles: - if role.id == Roles.verified: - log.warning(f"{ctx.author} posted '{ctx.message.content}' " - "in the verification channel, but is already verified.") - return # They're already verified + if any(r.id == Roles.verified for r in ctx.author.roles): + log.info( + f"{ctx.author} posted '{ctx.message.content}' " + "in the verification channel, but is already verified." + ) + return - log.debug(f"{ctx.author} posted '{ctx.message.content}' in the verification " - "channel. We are providing instructions how to verify.") + log.debug( + f"{ctx.author} posted '{ctx.message.content}' in the verification " + "channel. We are providing instructions how to verify." 
+ ) await ctx.send( f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " f"and gain access to the rest of the server.", @@ -112,11 +116,8 @@ class Verification(Cog): ) log.trace(f"Deleting the message posted by {ctx.author}") - - try: + with suppress(NotFound): await ctx.message.delete() - except NotFound: - log.trace("No message found, it must have been deleted by another bot.") @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) @without_role(Roles.verified) @@ -127,20 +128,16 @@ class Verification(Cog): await ctx.author.add_roles(Object(Roles.verified), reason="Accepted the rules") try: await ctx.author.send(WELCOME_MESSAGE) - except Exception: - # Catch the exception, in case they have DMs off or something - log.exception(f"Unable to send welcome message to user {ctx.author}.") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - try: - self.mod_log.ignore(Event.message_delete, ctx.message.id) - await ctx.message.delete() - except NotFound: - log.trace("No message found, it must have been deleted by another bot.") + except Forbidden: + log.info(f"Sending welcome message failed for {ctx.author}.") + finally: + log.trace(f"Deleting accept message by {ctx.author}.") + with suppress(NotFound): + self.mod_log.ignore(Event.message_delete, ctx.message.id) + await ctx.message.delete() @command(name='subscribe') - @in_channel(Channels.bot) + @in_channel(Channels.bot_commands) async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Subscribe to announcement notifications by assigning yourself the role.""" has_role = False @@ -164,7 +161,7 @@ class Verification(Cog): ) @command(name='unsubscribe') - @in_channel(Channels.bot) + @in_channel(Channels.bot_commands) async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Unsubscribe from announcement notifications by removing the role from yourself.""" has_role = False @@ -223,7 +220,7 @@ class Verification(Cog): @periodic_ping.before_loop async def before_ping(self) -> None: """Only start the loop when the bot is ready.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() def cog_unload(self) -> None: """Cancel the periodic ping task when the cog is unloaded.""" diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index eb787b083..479820444 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -9,7 +9,7 @@ from typing import Optional import dateutil.parser import discord -from discord import Color, Embed, HTTPException, Message, errors +from discord import Color, DMChannel, Embed, HTTPException, Message, errors from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError @@ -91,7 +91,7 @@ class WatchChannel(metaclass=CogABCMeta): async def start_watchchannel(self) -> None: """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() try: self.channel = await self.bot.fetch_channel(self.destination) @@ -273,7 +273,14 @@ class WatchChannel(metaclass=CogABCMeta): reason = self.watched_users[user_id]['reason'] - embed = Embed(description=f"{msg.author.mention} in [#{msg.channel.name}]({msg.jump_url})") + if isinstance(msg.channel, DMChannel): + # If a watched user DMs the bot there won't be a channel name or jump URL + # This could 
technically include a GroupChannel but bot's can't be in those + message_jump = "via DM" + else: + message_jump = f"in [#{msg.channel.name}]({msg.jump_url})" + + embed = Embed(description=f"{msg.author.mention} {message_jump}") embed.set_footer(text=f"Added {time_delta} by {actor} | Reason: {reason}") await self.webhook_send(embed=embed, username=msg.author.display_name, avatar_url=msg.author.avatar_url) diff --git a/bot/constants.py b/bot/constants.py index fe8e57322..14f8dc094 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -186,6 +186,11 @@ class YAMLGetter(type): def __getitem__(cls, name): return cls.__getattr__(name) + def __iter__(cls): + """Return generator of key: value pairs of current constants class' config values.""" + for name in cls.__annotations__: + yield name, getattr(cls, name) + # Dataclasses class Bot(metaclass=YAMLGetter): @@ -193,7 +198,7 @@ class Bot(metaclass=YAMLGetter): prefix: str token: str - + sentry_dsn: str class Filter(metaclass=YAMLGetter): section = "filter" @@ -263,6 +268,7 @@ class Emojis(metaclass=YAMLGetter): new: str pencil: str cross_mark: str + check_mark: str ducky_yellow: int ducky_blurple: int @@ -357,16 +363,16 @@ class Channels(metaclass=YAMLGetter): section = "guild" subsection = "channels" - admins: int admin_spam: int + admins: int announcements: int attachment_log: int big_brother_logs: int - bot: int - checkpoint_test: int + bot_commands: int defcon: int - devlog: int - devtest: int + dev_contrib: int + dev_core: int + dev_log: int esoteric: int help_0: int help_1: int @@ -379,19 +385,19 @@ class Channels(metaclass=YAMLGetter): helpers: int message_log: int meta: int + mod_alerts: int + mod_log: int mod_spam: int mods: int - mod_alerts: int - modlog: int off_topic_0: int off_topic_1: int off_topic_2: int organisation: int - python: int + python_discussion: int reddit: int talent_pool: int - userlog: int - user_event_a: int + user_event_announcements: int + user_log: int verification: int voice_log: int @@ -404,25 +410,25 @@ class Webhooks(metaclass=YAMLGetter): big_brother: int reddit: int duck_pond: int + dev_log: int class Roles(metaclass=YAMLGetter): section = "guild" subsection = "roles" - admin: int + admins: int announcements: int - champion: int - contributor: int - core_developer: int + contributors: int + core_developers: int helpers: int - jammer: int - moderator: int + jammers: int + moderators: int muted: int - owner: int + owners: int partners: int - rockstars: int - team_leader: int + python_community: int + team_leaders: int verified: int # This is the Developers role on PyDis, here named verified for readability reasons. @@ -430,9 +436,12 @@ class Guild(metaclass=YAMLGetter): section = "guild" id: int - ignored: List[int] + moderation_channels: List[int] + moderation_roles: List[int] + modlog_blacklist: List[int] + reminder_whitelist: List[int] staff_channels: List[int] - + staff_roles: List[int] class Keys(metaclass=YAMLGetter): section = "keys" @@ -537,6 +546,13 @@ class RedirectOutput(metaclass=YAMLGetter): delete_delay: int +class Sync(metaclass=YAMLGetter): + section = 'sync' + + confirm_timeout: int + max_diff: int + + class Event(Enum): """ Event names. 
This does not include every event (for example, raw @@ -571,14 +587,14 @@ BOT_DIR = os.path.dirname(__file__) PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir)) # Default role combinations -MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner -STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner +MODERATION_ROLES = Guild.moderation_roles +STAFF_ROLES = Guild.staff_roles # Roles combinations STAFF_CHANNELS = Guild.staff_channels # Default Channel combinations -MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam +MODERATION_CHANNELS = Guild.moderation_channels # Bot replies diff --git a/bot/converters.py b/bot/converters.py index cca57a02d..1945e1da3 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -141,40 +141,24 @@ class TagNameConverter(Converter): @staticmethod async def convert(ctx: Context, tag_name: str) -> str: """Lowercase & strip whitespace from proposed tag_name & ensure it's valid.""" - def is_number(value: str) -> bool: - """Check to see if the input string is numeric.""" - try: - float(value) - except ValueError: - return False - return True - tag_name = tag_name.lower().strip() # The tag name has at least one invalid character. if ascii(tag_name)[1:-1] != tag_name: - log.warning(f"{ctx.author} tried to put an invalid character in a tag name. " - "Rejecting the request.") raise BadArgument("Don't be ridiculous, you can't use that character!") # The tag name is either empty, or consists of nothing but whitespace. elif not tag_name: - log.warning(f"{ctx.author} tried to create a tag with a name consisting only of whitespace. " - "Rejecting the request.") raise BadArgument("Tag names should not be empty, or filled with whitespace.") - # The tag name is a number of some kind, we don't allow that. - elif is_number(tag_name): - log.warning(f"{ctx.author} tried to create a tag with a digit as its name. " - "Rejecting the request.") - raise BadArgument("Tag names can't be numbers.") - # The tag name is longer than 127 characters. elif len(tag_name) > 127: - log.warning(f"{ctx.author} tried to request a tag name with over 127 characters. " - "Rejecting the request.") raise BadArgument("Are you insane? That's way too long!") + # The tag name is ascii but does not contain any letters. + elif not any(character.isalpha() for character in tag_name): + raise BadArgument("Tag names must contain at least one letter.") + return tag_name @@ -192,8 +176,6 @@ class TagContentConverter(Converter): # The tag contents should not be empty, or filled with whitespace. if not tag_content: - log.warning(f"{ctx.author} tried to create a tag containing only whitespace. 
" - "Rejecting the request.") raise BadArgument("Tag contents should not be empty, or filled with whitespace.") return tag_content diff --git a/bot/pagination.py b/bot/pagination.py index 35870c040..90c8f849c 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -1,8 +1,9 @@ import asyncio import logging -from typing import Iterable, List, Optional, Tuple +import typing as t +from contextlib import suppress -from discord import Embed, Member, Message, Reaction +import discord from discord.abc import User from discord.ext.commands import Context, Paginator @@ -14,7 +15,7 @@ RIGHT_EMOJI = "\u27A1" # [:arrow_right:] LAST_EMOJI = "\u23ED" # [:track_next:] DELETE_EMOJI = constants.Emojis.trashcan # [:trashcan:] -PAGINATION_EMOJI = [FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI] +PAGINATION_EMOJI = (FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI) log = logging.getLogger(__name__) @@ -89,12 +90,12 @@ class LinePaginator(Paginator): @classmethod async def paginate( cls, - lines: Iterable[str], + lines: t.List[str], ctx: Context, - embed: Embed, + embed: discord.Embed, prefix: str = "", suffix: str = "", - max_lines: Optional[int] = None, + max_lines: t.Optional[int] = None, max_size: int = 500, empty: bool = True, restrict_to_user: User = None, @@ -102,7 +103,7 @@ class LinePaginator(Paginator): footer_text: str = None, url: str = None, exception_on_empty_embed: bool = False - ) -> Optional[Message]: + ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -114,11 +115,11 @@ class LinePaginator(Paginator): Pagination will also be removed automatically if no reaction is added for five minutes (300 seconds). Example: - >>> embed = Embed() + >>> embed = discord.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) - >>> await LinePaginator.paginate((line for line in lines), ctx, embed) + >>> await LinePaginator.paginate([line for line in lines], ctx, embed) """ - def event_check(reaction_: Reaction, user_: Member) -> bool: + def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool: """Make sure that this reaction is what we want to operate on.""" no_restrictions = ( # Pagination is not restricted @@ -281,8 +282,9 @@ class LinePaginator(Paginator): await message.edit(embed=embed) - log.debug("Ending pagination and removing all reactions...") - await message.clear_reactions() + log.debug("Ending pagination and clearing reactions.") + with suppress(discord.NotFound): + await message.clear_reactions() class ImagePaginator(Paginator): @@ -299,6 +301,7 @@ class ImagePaginator(Paginator): self._current_page = [prefix] self.images = [] self._pages = [] + self._count = 0 def add_line(self, line: str = '', *, empty: bool = False) -> None: """Adds a line to each page.""" @@ -316,13 +319,13 @@ class ImagePaginator(Paginator): @classmethod async def paginate( cls, - pages: List[Tuple[str, str]], - ctx: Context, embed: Embed, + pages: t.List[t.Tuple[str, str]], + ctx: Context, embed: discord.Embed, prefix: str = "", suffix: str = "", timeout: int = 300, exception_on_empty_embed: bool = False - ) -> Optional[Message]: + ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of title/image pairs. @@ -334,11 +337,11 @@ class ImagePaginator(Paginator): Note: Pagination will be removed automatically if no reaction is added for five minutes (300 seconds). 
Example: - >>> embed = Embed() + >>> embed = discord.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) >>> await ImagePaginator.paginate(pages, ctx, embed) """ - def check_event(reaction_: Reaction, member: Member) -> bool: + def check_event(reaction_: discord.Reaction, member: discord.Member) -> bool: """Checks each reaction added, if it matches our conditions pass the wait_for.""" return all(( # Reaction is on the same message sent @@ -410,7 +413,7 @@ class ImagePaginator(Paginator): log.debug("Got last page reaction, but we're on the last page - ignoring") continue - current_page = len(paginator.pages - 1) + current_page = len(paginator.pages) - 1 reaction_type = "last" # Previous reaction press - [:arrow_left: ] @@ -445,5 +448,6 @@ class ImagePaginator(Paginator): await message.edit(embed=embed) - log.debug("Ending pagination and removing all reactions...") - await message.clear_reactions() + log.debug("Ending pagination and clearing reactions.") + with suppress(discord.NotFound): + await message.clear_reactions() diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 00bb2a949..8903c385c 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -19,7 +19,7 @@ async def apply( if total_recent_attachments > config['max']: return ( - f"sent {total_recent_attachments} attachments in {config['max']}s", + f"sent {total_recent_attachments} attachments in {config['interval']}s", (last_message.author,), relevant_messages ) diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 8184be824..9b32e515d 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,4 @@ from abc import ABCMeta -from typing import Any, Generator, Hashable, Iterable from discord.ext.commands import CogMeta @@ -8,69 +7,3 @@ class CogABCMeta(CogMeta, ABCMeta): """Metaclass for ABCs meant to be implemented as Cogs.""" pass - - -class CaseInsensitiveDict(dict): - """ - We found this class on StackOverflow. Thanks to m000 for writing it! 
- - https://stackoverflow.com/a/32888599/4022104 - """ - - @classmethod - def _k(cls, key: Hashable) -> Hashable: - """Return lowered key if a string-like is passed, otherwise pass key straight through.""" - return key.lower() if isinstance(key, str) else key - - def __init__(self, *args, **kwargs): - super(CaseInsensitiveDict, self).__init__(*args, **kwargs) - self._convert_keys() - - def __getitem__(self, key: Hashable) -> Any: - """Case insensitive __setitem__.""" - return super(CaseInsensitiveDict, self).__getitem__(self.__class__._k(key)) - - def __setitem__(self, key: Hashable, value: Any): - """Case insensitive __setitem__.""" - super(CaseInsensitiveDict, self).__setitem__(self.__class__._k(key), value) - - def __delitem__(self, key: Hashable) -> Any: - """Case insensitive __delitem__.""" - return super(CaseInsensitiveDict, self).__delitem__(self.__class__._k(key)) - - def __contains__(self, key: Hashable) -> bool: - """Case insensitive __contains__.""" - return super(CaseInsensitiveDict, self).__contains__(self.__class__._k(key)) - - def pop(self, key: Hashable, *args, **kwargs) -> Any: - """Case insensitive pop.""" - return super(CaseInsensitiveDict, self).pop(self.__class__._k(key), *args, **kwargs) - - def get(self, key: Hashable, *args, **kwargs) -> Any: - """Case insensitive get.""" - return super(CaseInsensitiveDict, self).get(self.__class__._k(key), *args, **kwargs) - - def setdefault(self, key: Hashable, *args, **kwargs) -> Any: - """Case insensitive setdefault.""" - return super(CaseInsensitiveDict, self).setdefault(self.__class__._k(key), *args, **kwargs) - - def update(self, E: Any = None, **F) -> None: - """Case insensitive update.""" - super(CaseInsensitiveDict, self).update(self.__class__(E)) - super(CaseInsensitiveDict, self).update(self.__class__(**F)) - - def _convert_keys(self) -> None: - """Helper method to lowercase all existing string-like keys.""" - for k in list(self.keys()): - v = super(CaseInsensitiveDict, self).pop(k) - self.__setitem__(k, v) - - -def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]: - """ - Generator that allows you to iterate over any indexable collection in `size`-length chunks. - - Found: https://stackoverflow.com/a/312464/4022104 - """ - for i in range(0, len(iterable), size): - yield iterable[i:i + size] diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index ee6c0a8e6..5760ec2d4 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -1,8 +1,9 @@ import asyncio import contextlib import logging +import typing as t from abc import abstractmethod -from typing import Coroutine, Dict, Union +from functools import partial from bot.utils import CogABCMeta @@ -13,12 +14,13 @@ class Scheduler(metaclass=CogABCMeta): """Task scheduler.""" def __init__(self): + # Keep track of the child cog's name so the logs are clear. + self.cog_name = self.__class__.__name__ - self.cog_name = self.__class__.__name__ # keep track of the child cog's name so the logs are clear. - self.scheduled_tasks: Dict[str, asyncio.Task] = {} + self._scheduled_tasks: t.Dict[t.Hashable, asyncio.Task] = {} @abstractmethod - async def _scheduled_task(self, task_object: dict) -> None: + async def _scheduled_task(self, task_object: t.Any) -> None: """ A coroutine which handles the scheduling. @@ -29,46 +31,73 @@ class Scheduler(metaclass=CogABCMeta): then make a site API request to delete the reminder from the database. 
""" - def schedule_task(self, loop: asyncio.AbstractEventLoop, task_id: str, task_data: dict) -> None: + def schedule_task(self, task_id: t.Hashable, task_data: t.Any) -> None: """ Schedules a task. - `task_data` is passed to `Scheduler._scheduled_expiration` + `task_data` is passed to the `Scheduler._scheduled_task()` coroutine. """ - if task_id in self.scheduled_tasks: + log.trace(f"{self.cog_name}: scheduling task #{task_id}...") + + if task_id in self._scheduled_tasks: log.debug( f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled." ) return - task: asyncio.Task = create_task(loop, self._scheduled_task(task_data)) + task = asyncio.create_task(self._scheduled_task(task_data)) + task.add_done_callback(partial(self._task_done_callback, task_id)) - self.scheduled_tasks[task_id] = task - log.debug(f"{self.cog_name}: scheduled task #{task_id}.") + self._scheduled_tasks[task_id] = task + log.debug(f"{self.cog_name}: scheduled task #{task_id} {id(task)}.") - def cancel_task(self, task_id: str) -> None: - """Un-schedules a task.""" - task = self.scheduled_tasks.get(task_id) + def cancel_task(self, task_id: t.Hashable) -> None: + """Unschedule the task identified by `task_id`.""" + log.trace(f"{self.cog_name}: cancelling task #{task_id}...") + task = self._scheduled_tasks.get(task_id) - if task is None: - log.warning(f"{self.cog_name}: Failed to unschedule {task_id} (no task found).") + if not task: + log.warning(f"{self.cog_name}: failed to unschedule {task_id} (no task found).") return task.cancel() - log.debug(f"{self.cog_name}: unscheduled task #{task_id}.") - del self.scheduled_tasks[task_id] + del self._scheduled_tasks[task_id] + + log.debug(f"{self.cog_name}: unscheduled task #{task_id} {id(task)}.") + def _task_done_callback(self, task_id: t.Hashable, done_task: asyncio.Task) -> None: + """ + Delete the task and raise its exception if one exists. -def create_task(loop: asyncio.AbstractEventLoop, coro_or_future: Union[Coroutine, asyncio.Future]) -> asyncio.Task: - """Creates an asyncio.Task object from a coroutine or future object.""" - task: asyncio.Task = asyncio.ensure_future(coro_or_future, loop=loop) + If `done_task` and the task associated with `task_id` are different, then the latter + will not be deleted. In this case, a new task was likely rescheduled with the same ID. + """ + log.trace(f"{self.cog_name}: performing done callback for task #{task_id} {id(done_task)}.") - # Silently ignore exceptions in a callback (handles the CancelledError nonsense) - task.add_done_callback(_silent_exception) - return task + scheduled_task = self._scheduled_tasks.get(task_id) + if scheduled_task and done_task is scheduled_task: + # A task for the ID exists and its the same as the done task. + # Since this is the done callback, the task is already done so no need to cancel it. + log.trace(f"{self.cog_name}: deleting task #{task_id} {id(done_task)}.") + del self._scheduled_tasks[task_id] + elif scheduled_task: + # A new task was likely rescheduled with the same ID. + log.debug( + f"{self.cog_name}: the scheduled task #{task_id} {id(scheduled_task)} " + f"and the done task {id(done_task)} differ." + ) + elif not done_task.cancelled(): + log.warning( + f"{self.cog_name}: task #{task_id} not found while handling task {id(done_task)}! " + f"A task somehow got unscheduled improperly (i.e. deleted but not cancelled)." 
+ ) -def _silent_exception(future: asyncio.Future) -> None: - """Suppress future's exception.""" - with contextlib.suppress(Exception): - future.exception() + with contextlib.suppress(asyncio.CancelledError): + exception = done_task.exception() + # Log the exception if one exists. + if exception: + log.error( + f"{self.cog_name}: error in task #{task_id} {id(scheduled_task)}!", + exc_info=exception + ) diff --git a/bot/utils/time.py b/bot/utils/time.py index 7416f36e0..77060143c 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -114,30 +114,40 @@ def format_infraction(timestamp: str) -> str: def format_infraction_with_duration( - expiry: Optional[str], + date_to: Optional[str], date_from: Optional[datetime.datetime] = None, - max_units: int = 2 + max_units: int = 2, + absolute: bool = True ) -> Optional[str]: """ - Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. + Return `date_to` formatted as a readable ISO-8601 with the humanized duration since `date_from`. - Returns a human-readable version of the duration between datetime.utcnow() and an expiry. - Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. - `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). - By default, max_units is 2. + `date_from` must be an ISO-8601 formatted timestamp. The duration is calculated as from + `date_from` until `date_to` with a precision of seconds. If `date_from` is unspecified, the + current time is used. + + `max_units` specifies the maximum number of units of time to include in the duration. For + example, a value of 1 may include days but not hours. + + If `absolute` is True, the absolute value of the duration delta is used. This prevents negative + values in the case that `date_to` is in the past relative to `date_from`. """ - if not expiry: + if not date_to: return None + date_to_formatted = format_infraction(date_to) + date_from = date_from or datetime.datetime.utcnow() - date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + date_to = dateutil.parser.isoparse(date_to).replace(tzinfo=None, microsecond=0) - expiry_formatted = format_infraction(expiry) + delta = relativedelta(date_to, date_from) + if absolute: + delta = abs(delta) - duration = humanize_delta(relativedelta(date_to, date_from), max_units=max_units) - duration_formatted = f" ({duration})" if duration else '' + duration = humanize_delta(delta, max_units=max_units) + duration_formatted = f" ({duration})" if duration else "" - return f"{expiry_formatted}{duration_formatted}" + return f"{date_to_formatted}{duration_formatted}" def until_expiration( diff --git a/config-default.yml b/config-default.yml index fda14b511..5788d1e12 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,6 +1,7 @@ bot: prefix: "!" token: !ENV "BOT_TOKEN" + sentry_dsn: !ENV "BOT_SENTRY_DSN" cooldowns: # Per channel, per tag. 
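For orientation on the `format_infraction_with_duration` rework in `bot/utils/time.py` above, a hedged usage sketch follows; the exact output strings depend on `format_infraction` and `humanize_delta`, so the comments are only indicative:

```python
import datetime

from bot.utils.time import format_infraction_with_duration

now = datetime.datetime(2020, 3, 1, 12, 0)

# A future expiry: the formatted timestamp plus a humanized duration from `now`,
# e.g. something along the lines of "2020-03-03 14:30 (2 days and 2 hours)".
print(format_infraction_with_duration("2020-03-03T14:30:00", date_from=now))

# With the default absolute=True, a date in the past still produces a positive
# duration; pass absolute=False to keep the signed delta instead.
print(format_infraction_with_duration("2020-02-28T12:00:00", date_from=now))

# No expiry short-circuits to None.
assert format_infraction_with_duration(None) is None
```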
@@ -34,6 +35,7 @@ style: pencil: "\u270F" new: "\U0001F195" cross_mark: "\u274C" + check_mark: "\u2705" ducky_yellow: &DUCKY_YELLOW 574951975574175744 ducky_blurple: &DUCKY_BLURPLE 574951975310065675 @@ -109,74 +111,135 @@ guild: id: 267624335836053506 categories: - python_help: 356013061213126657 + python_help: 356013061213126657 channels: - admins: &ADMINS 365960823622991872 - admin_spam: &ADMIN_SPAM 563594791770914816 - admins_voice: &ADMINS_VOICE 500734494840717332 - announcements: 354619224620138496 - attachment_log: &ATTCH_LOG 649243850006855680 - big_brother_logs: &BBLOGS 468507907357409333 - bot: 267659945086812160 - checkpoint_test: 422077681434099723 - defcon: &DEFCON 464469101889454091 - devlog: &DEVLOG 622895325144940554 - devtest: &DEVTEST 414574275865870337 - esoteric: 470884583684964352 - help_0: 303906576991780866 - help_1: 303906556754395136 - help_2: 303906514266226689 - help_3: 439702951246692352 - help_4: 451312046647148554 - help_5: 454941769734422538 - help_6: 587375753306570782 - help_7: 587375768556797982 - helpers: &HELPERS 385474242440986624 - message_log: &MESSAGE_LOG 467752170159079424 - meta: 429409067623251969 - mod_spam: &MOD_SPAM 620607373828030464 - mods: &MODS 305126844661760000 - mod_alerts: 473092532147060736 - modlog: &MODLOG 282638479504965634 - off_topic_0: 291284109232308226 - off_topic_1: 463035241142026251 - off_topic_2: 463035268514185226 - organisation: &ORGANISATION 551789653284356126 - python: 267624335836053506 - reddit: 458224812528238616 - staff_lounge: &STAFF_LOUNGE 464905259261755392 - staff_voice: &STAFF_VOICE 412375055910043655 - talent_pool: &TALENT_POOL 534321732593647616 - userlog: 528976905546760203 - user_event_a: &USER_EVENT_A 592000283102674944 - verification: 352442727016693763 - voice_log: 640292421988646961 - - staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON] - ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG, *ADMINS_VOICE, *STAFF_VOICE, *ATTCH_LOG] + announcements: 354619224620138496 + user_event_announcements: &USER_EVENT_A 592000283102674944 + + # Development + dev_contrib: &DEV_CONTRIB 635950537262759947 + dev_core: &DEV_CORE 411200599653351425 + dev_log: &DEV_LOG 622895325144940554 + + # Discussion + meta: 429409067623251969 + python_discussion: 267624335836053506 + + # Logs + attachment_log: &ATTACH_LOG 649243850006855680 + message_log: &MESSAGE_LOG 467752170159079424 + mod_log: &MOD_LOG 282638479504965634 + user_log: 528976905546760203 + voice_log: 640292421988646961 + + # Off-topic + off_topic_0: 291284109232308226 + off_topic_1: 463035241142026251 + off_topic_2: 463035268514185226 + + # Python Help + help_0: 303906576991780866 + help_1: 303906556754395136 + help_2: 303906514266226689 + help_3: 439702951246692352 + help_4: 451312046647148554 + help_5: 454941769734422538 + help_6: 587375753306570782 + help_7: 587375768556797982 + + # Special + bot_commands: &BOT_CMD 267659945086812160 + esoteric: 470884583684964352 + reddit: 458224812528238616 + verification: 352442727016693763 + + # Staff + admins: &ADMINS 365960823622991872 + admin_spam: &ADMIN_SPAM 563594791770914816 + defcon: &DEFCON 464469101889454091 + helpers: &HELPERS 385474242440986624 + mods: &MODS 305126844661760000 + mod_alerts: &MOD_ALERTS 473092532147060736 + mod_spam: &MOD_SPAM 620607373828030464 + organisation: &ORGANISATION 551789653284356126 + staff_lounge: &STAFF_LOUNGE 464905259261755392 + + # Voice + admins_voice: &ADMINS_VOICE 500734494840717332 + staff_voice: &STAFF_VOICE 412375055910043655 + + # Watch + 
big_brother_logs: &BB_LOGS 468507907357409333 + talent_pool: &TALENT_POOL 534321732593647616 + + staff_channels: + - *ADMINS + - *ADMIN_SPAM + - *DEFCON + - *HELPERS + - *MODS + - *MOD_SPAM + - *ORGANISATION + + moderation_channels: + - *ADMINS + - *ADMIN_SPAM + - *MOD_ALERTS + - *MODS + - *MOD_SPAM + + # Modlog cog ignores events which occur in these channels + modlog_blacklist: + - *ADMINS + - *ADMINS_VOICE + - *ATTACH_LOG + - *MESSAGE_LOG + - *MOD_LOG + - *STAFF_VOICE + + reminder_whitelist: + - *BOT_CMD + - *DEV_CONTRIB roles: - admin: &ADMIN_ROLE 267628507062992896 - announcements: 463658397560995840 - champion: 430492892331769857 - contributor: 295488872404484098 - core_developer: 587606783669829632 - helpers: 267630620367257601 - jammer: 591786436651646989 - moderator: &MOD_ROLE 267629731250176001 - muted: &MUTED_ROLE 277914926603829249 - owner: &OWNER_ROLE 267627879762755584 - partners: 323426753857191936 - rockstars: &ROCKSTARS_ROLE 458226413825294336 - team_leader: 501324292341104650 - verified: 352427296948486144 + announcements: 463658397560995840 + contributors: 295488872404484098 + muted: &MUTED_ROLE 277914926603829249 + partners: 323426753857191936 + python_community: &PY_COMMUNITY_ROLE 458226413825294336 + + # This is the Developers role on PyDis, here named verified for readability reasons + verified: 352427296948486144 + + # Staff + admins: &ADMINS_ROLE 267628507062992896 + core_developers: 587606783669829632 + helpers: &HELPERS_ROLE 267630620367257601 + moderators: &MODS_ROLE 267629731250176001 + owners: &OWNERS_ROLE 267627879762755584 + + # Code Jam + jammers: 591786436651646989 + team_leaders: 501324292341104650 + + moderation_roles: + - *OWNERS_ROLE + - *ADMINS_ROLE + - *MODS_ROLE + + staff_roles: + - *OWNERS_ROLE + - *ADMINS_ROLE + - *MODS_ROLE + - *HELPERS_ROLE webhooks: - talent_pool: 569145364800602132 - big_brother: 569133704568373283 - reddit: 635408384794951680 - duck_pond: 637821475327311927 + talent_pool: 569145364800602132 + big_brother: 569133704568373283 + reddit: 635408384794951680 + duck_pond: 637821475327311927 + dev_log: 680501655111729222 filter: @@ -216,10 +279,35 @@ filter: - 438622377094414346 # Pyglet - 524691714909274162 # Panda3D - 336642139381301249 # discord.py + - 405403391410438165 # Sentdex domain_blacklist: - pornhub.com - liveleak.com + - grabify.link + - bmwforum.co + - leancoding.co + - spottyfly.com + - stopify.co + - yoütu.be + - discörd.com + - minecräft.com + - freegiftcards.co + - disçordapp.com + - fortnight.space + - fortnitechat.site + - joinmy.site + - curiouscat.club + - catsnthings.fun + - yourtube.site + - youtubeshort.watch + - catsnthing.com + - youtubeshort.pro + - canadianlumberjacks.online + - poweredbydialup.club + - poweredbydialup.online + - poweredbysecurity.org + - poweredbysecurity.online word_watchlist: - goo+ks* @@ -253,20 +341,20 @@ filter: # Censor doesn't apply to these channel_whitelist: - *ADMINS - - *MODLOG + - *MOD_LOG - *MESSAGE_LOG - - *DEVLOG - - *BBLOGS + - *DEV_LOG + - *BB_LOGS - *STAFF_LOUNGE - - *DEVTEST - *TALENT_POOL - *USER_EVENT_A role_whitelist: - - *ADMIN_ROLE - - *MOD_ROLE - - *OWNER_ROLE - - *ROCKSTARS_ROLE + - *ADMINS_ROLE + - *MODS_ROLE + - *OWNERS_ROLE + - *HELPERS_ROLE + - *PY_COMMUNITY_ROLE keys: @@ -302,7 +390,7 @@ urls: paste_service: !JOIN [*SCHEMA, *PASTE, "/{key}"] # Snekbox - snekbox_eval_api: "https://snekbox.pythondiscord.com/eval" + snekbox_eval_api: "http://snekbox:8060/eval" # Discord API URLs discord_api: &DISCORD_API "https://discordapp.com/api/v7/" @@ -428,9 +516,26 @@ 
redirect_output: delete_invocation: true delete_delay: 15 +sync: + confirm_timeout: 300 + max_diff: 10 + duck_pond: threshold: 5 - custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA] + custom_emojis: + - *DUCKY_YELLOW + - *DUCKY_BLURPLE + - *DUCKY_CAMO + - *DUCKY_DEVIL + - *DUCKY_NINJA + - *DUCKY_REGAL + - *DUCKY_TUBE + - *DUCKY_HUNT + - *DUCKY_WIZARD + - *DUCKY_PARTY + - *DUCKY_ANGEL + - *DUCKY_MAUL + - *DUCKY_SANTA config: required_keys: ['bot.token'] diff --git a/docker-compose.yml b/docker-compose.yml index 7281c7953..11deceae8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,6 +23,7 @@ services: - staff.web ports: - "127.0.0.1:8000:8000" + tty: true depends_on: - postgres environment: @@ -37,6 +38,7 @@ services: volumes: - ./logs:/bot/logs - .:/bot:ro + tty: true depends_on: - web environment: diff --git a/tests/README.md b/tests/README.md index be78821bf..4f62edd68 100644 --- a/tests/README.md +++ b/tests/README.md @@ -83,7 +83,7 @@ TagContentConverter should return correct values for valid input. As we are trying to test our "units" of code independently, we want to make sure that we do not rely objects and data generated by "external" code. If we we did, then we wouldn't know if the failure we're observing was caused by the code we are actually trying to test or something external to it. -However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". +However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks". To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). @@ -114,13 +114,13 @@ class BotCogTests(unittest.TestCase): ### Mocking coroutines -By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8. +By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). 
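A rough sketch of how that helper can be used, assuming only what the text above states: it behaves like a `MagicMock` whose `__call__` is a coroutine, so `return_value` and the usual call assertions work:

```python
import asyncio
import unittest

from tests import helpers


class AsyncMockUsageTests(unittest.TestCase):
    def test_mocked_coroutine_can_be_awaited(self):
        fetch_data = helpers.AsyncMock(return_value=42)

        # Calling the mock returns a coroutine, so the code under test can
        # await it exactly as it would await the real coroutine function.
        result = asyncio.run(fetch_data("some endpoint"))

        self.assertEqual(result, 42)
        fetch_data.assert_called_once_with("some endpoint")
```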
Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8. ### Special mocks for some `discord.py` types To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. -In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**. These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. @@ -144,7 +144,7 @@ Finally, there are some considerations to make when writing tests, both for writ ### Test coverage is a starting point -Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work. +Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work. One problem is that 100% branch coverage may be misleading if we haven't tested our code against all the realistic input it may get in production. For instance, take a look at the following `member_information` function and the test we've written for it: @@ -169,7 +169,7 @@ class FunctionsTests(unittest.TestCase): If you were to run this test, not only would the function pass the test, `coverage.py` will also tell us that the test provides 100% branch coverage for the function. Can you spot the bug the test suite did not catch? -The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`). +The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`). 
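As a hedged sketch of a test that exercises that untested branch (it assumes the README's `member_information` example function is in scope and returns a string containing the formatted join date):

```python
import datetime
import unittest

from tests import helpers


class FunctionsTests(unittest.TestCase):
    def test_member_information_with_real_join_date(self):
        # A member with an actual join date reaches the branch that calls
        # member.joined.strftime("%d-%m-%Y"), which the None-based test never ran.
        member = helpers.MockMember()
        member.joined = datetime.datetime(2019, 10, 10)

        # member_information is the helper from the README's example; with the
        # misspelled `stfptime` this call raises AttributeError and the test
        # fails, which is exactly the signal we want.
        result = member_information(member)

        self.assertIn("10-10-2019", result)
```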
Adding another test would not increase the test coverage we have, but it does ensure that we'll notice that this function can fail with realistic data: diff --git a/tests/base.py b/tests/base.py index 029a249ed..d99b9ac31 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@ import logging import unittest from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers class _CaptureLogHandler(logging.Handler): @@ -16,11 +22,16 @@ class _CaptureLogHandler(logging.Handler): self.records.append(record) -class LoggingTestCase(unittest.TestCase): - """TestCase subclass that adds more logging assertion tools.""" +class LoggingTestsMixin: + """ + A mixin that defines additional test methods for logging behavior. + + This mixin relies on the availability of the `fail` attribute defined by the + test classes included in Python's unittest method to signal test failure. + """ @contextmanager - def assertNotLogs(self, logger=None, level=None, msg=None): + def assertNotLogs(self, logger=None, level=None, msg=None): # noqa: N802 """ Asserts that no logs of `level` and higher were emitted by `logger`. @@ -65,3 +76,30 @@ class LoggingTestCase(unittest.TestCase): standard_message = self._truncateMessage(base_message, record_message) msg = self._formatMessage(msg, standard_message) self.fail(msg) + + +class CommandTestCase(unittest.IsolatedAsyncioTestCase): + """TestCase with additional assertions that are useful for testing Discord commands.""" + + async def assertHasPermissionsCheck( # noqa: N802 + self, + cmd: commands.Command, + permissions: Dict[str, bool], + ) -> None: + """ + Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + + Every permission in `permissions` is expected to be reported as missing. In other words, do + not include permissions which should not raise an exception along with those which should. + """ + # Invert permission values because it's more intuitive to pass to this assertion the same + # permissions as those given to the check decorator. 
+ permissions = {k: not v for k, v in permissions.items()} + + ctx = helpers.MockContext() + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + + with self.assertRaises(commands.MissingPermissions) as cm: + await cmd.can_run(ctx) + + self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..6ee9dfda6 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,403 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): + """Syncer subclass with mocks for abstract methods for testing purposes.""" + + name = "test" + _get_diff = mock.AsyncMock() + _sync = mock.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): + """Tests for the syncer base class.""" + + def setUp(self): + self.bot = helpers.MockBot() + + def test_instantiation_fails_without_abstract_methods(self): + """The class must have abstract methods implemented.""" + with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.IsolatedAsyncioTestCase): + """Tests for sending the sync confirmation prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + + def mock_get_channel(self): + """Fixture to return a mock channel and message for when `get_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + mock_channel.send.return_value = mock_message + self.bot.get_channel.return_value = mock_channel + + return mock_channel, mock_message + + def mock_fetch_channel(self): + """Fixture to return a mock channel and message for when `fetch_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + self.bot.get_channel.return_value = None + mock_channel.send.return_value = mock_message + self.bot.fetch_channel.return_value = mock_channel + + return mock_channel, mock_message + + async def test_send_prompt_edits_and_returns_message(self): + """The given message should be edited to display the prompt and then should be returned.""" + msg = helpers.MockMessage() + ret_val = await self.syncer._send_prompt(msg) + + msg.edit.assert_called_once() + self.assertIn("content", msg.edit.call_args[1]) + self.assertEqual(ret_val, msg) + + async def test_send_prompt_gets_dev_core_channel(self): + """The dev-core channel should be retrieved if an extant message isn't given.""" + subtests = ( + (self.bot.get_channel, self.mock_get_channel), + (self.bot.fetch_channel, self.mock_fetch_channel), + ) + + for method, mock_ in subtests: + with self.subTest(method=method, msg=mock_.__name__): + mock_() + await self.syncer._send_prompt() + + method.assert_called_once_with(constants.Channels.dev_core) + + async def test_send_prompt_returns_none_if_channel_fetch_fails(self): + """None should be returned if there's an HTTPException when fetching the channel.""" + self.bot.get_channel.return_value = None + self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + + ret_val = await self.syncer._send_prompt() + + self.assertIsNone(ret_val) + + async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): + """A 
new message mentioning core devs should be sent and returned if message isn't given.""" + for mock_ in (self.mock_get_channel, self.mock_fetch_channel): + with self.subTest(msg=mock_.__name__): + mock_channel, mock_message = mock_() + ret_val = await self.syncer._send_prompt() + + mock_channel.send.assert_called_once() + self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) + self.assertEqual(ret_val, mock_message) + + async def test_send_prompt_adds_reactions(self): + """The message should have reactions for confirmation added.""" + extant_message = helpers.MockMessage() + subtests = ( + (extant_message, lambda: (None, extant_message)), + (None, self.mock_get_channel), + (None, self.mock_fetch_channel), + ) + + for message_arg, mock_ in subtests: + subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.IsolatedAsyncioTestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) 
+ with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for 
message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = mock.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + # Make sure `_get_diff` returns a MagicMock, not an AsyncMock + self.syncer._get_diff.return_value = mock.MagicMock() + + self.syncer._get_confirmation_result = mock.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): # pragma: no cover + with self.subTest(size=size): + self.syncer._send_prompt = mock.AsyncMock() + self.syncer._wait_for_confirmation = mock.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: # pragma: no cover + with self.subTest(msg=msg): + self.syncer._send_prompt = mock.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = mock.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..81398c61f --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,368 @@ +import 
unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class SyncExtensionTests(unittest.IsolatedAsyncioTestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.IsolatedAsyncioTestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + self.role_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.RoleSyncer", + autospec=Syncer, + spec_set=True + ) + self.user_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.UserSyncer", + autospec=Syncer, + spec_set=True + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = sync.Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(sync.Sync, "sync_guild", new_callable=mock.MagicMock) + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task = mock.MagicMock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + sync.Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors 
ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = mock.AsyncMock(spec_set=self.cog.patch_user) + + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data) + after_role = helpers.MockRole(**after_role_data) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + async def test_sync_cog_on_member_remove(self): + """Member should patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember() + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + updated_information={"in_guild": False} + ) + + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. 
+ before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles) + after_member = helpers.MockMember(roles=before_roles[1:]) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}) + after_member = helpers.MockMember(**{attribute: new_value}) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "avatar": "old avatar", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args[0][0], after_user.id) + self.assertIn("updated_information", call_args[1]) + + updated_information = call_args[1]["updated_information"] + self.assertIn(api_field, updated_information) + self.assertEqual(updated_information[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. 
+ """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + ) + + data = { + "avatar_hash": member.avatar, + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + async def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + await self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..79eee98f4 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,157 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): - """Tests constructing the roles to synchronize with the site.""" - - def test_get_roles_for_sync_empty_return_for_equal_roles(self): - """No roles should be synced when no diff is found.""" - api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), set(), set()) - ) - - def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): - """Roles to be synced are returned when non-ID attributes differ.""" - api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - - self.assertEqual( - 
get_roles_for_sync(guild_roles, api_roles), - (set(), guild_roles, set()) - ) - - def test_get_roles_only_returns_roles_that_require_update(self): - """Roles that require an update should be returned as the second tuple element.""" - api_roles = { - Role(id=41, name='old name', colour=33, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - guild_roles = { - Role(id=41, name='new name', colour=35, permissions=0x8, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_new_roles_in_first_tuple_element(self): - """Newly created roles are returned as the first tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=2) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, - set(), - set(), - ) - ) - - def test_get_roles_returns_roles_to_update_and_new_roles(self): - """Newly created and updated roles should be returned together.""" - api_roles = { - Role(id=41, name='old name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='new name', colour=40, permissions=0x16, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, - {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_roles_to_delete(self): - """Roles to be deleted should be returned as the third tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - set(), - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) - - def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): - """When roles were added, updated, and removed, all of them are returned properly.""" - api_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - Role(id=71, name='to update', colour=99, permissions=0x9, position=3), - } - guild_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=81, name='to create', colour=99, permissions=0x9, position=4), - Role(id=71, name='updated', colour=101, permissions=0x5, position=3), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, - {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def 
fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + async def test_sync_created_roles(self): + """Only POST requests 
should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..818883012 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,160 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers def fake_user(**kwargs): - kwargs.setdefault('id', 43) - kwargs.setdefault('name', 'bob the test man') - kwargs.setdefault('discriminator', 1337) - kwargs.setdefault('avatar_hash', None) - kwargs.setdefault('roles', (666,)) - kwargs.setdefault('in_guild', True) - return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): - """Tests constructing the users to synchronize with the site.""" - - def test_get_users_for_sync_returns_nothing_for_empty_params(self): - """When no users are given, none are returned.""" - self.assertEqual( - get_users_for_sync({}, {}), - (set(), set()) - ) - - def test_get_users_for_sync_returns_nothing_for_equal_users(self): - """When no users are updated, none are returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) - - def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): - """When a non-ID-field differs, the user to update is returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(name='new fancy name')} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(name='new fancy name')}) - ) - - def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): - """When new users join the guild, they are returned as the first tuple element.""" - api_users = {43: fake_user()} 
- guild_users = {43: fake_user(), 63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, set()) - ) - - def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("avatar_hash", None) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.IsolatedAsyncioTestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + member["avatar"] = member.pop("avatar_hash") + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - api_users = {43: fake_user(), 63: fake_user(id=63)} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(id=63, in_guild=False)}) - ) - - def test_get_users_for_sync_updates_and_creates_users_as_needed(self): - """When one user left and another one was updated, both are returned.""" - api_users = {43: fake_user()} - guild_users = {63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, {fake_user(in_guild=False)}) - ) - - def 
test_get_users_for_sync_does_not_duplicate_update_users(self): - """When the API knows a user the guild doesn't, nothing is performed.""" - api_users = {43: fake_user(in_guild=False)} - guild_users = {} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.IsolatedAsyncioTestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..7e6bfc748 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -2,7 +2,7 @@ import asyncio import logging import typing import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import discord @@ -14,7 +14,7 @@ from tests import helpers MODULE_PATH = "bot.cogs.duck_pond" -class 
DuckPondTests(base.LoggingTestCase): +class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): """Tests for DuckPond functionality.""" @classmethod @@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase): asyncio.run(self.cog.fetch_webhook()) - self.bot.wait_until_ready.assert_called_once() + self.bot.wait_until_guild_available.assert_called_once() self.bot.fetch_webhook.assert_called_once_with(1) self.assertEqual(self.cog.webhook, "dummy webhook") @@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase): with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: asyncio.run(self.cog.fetch_webhook()) - self.bot.wait_until_ready.assert_called_once() + self.bot.wait_until_guild_available.assert_called_once() self.bot.fetch_webhook.assert_called_once_with(1) self.assertEqual(len(log_watcher.records), 1) @@ -88,7 +88,6 @@ class DuckPondTests(base.LoggingTestCase): with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): self.assertEqual(expected_return, actual_return) - @helpers.async_test async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): """The `has_green_checkmark` method should only return `True` if one is present.""" test_cases = ( @@ -172,7 +171,6 @@ class DuckPondTests(base.LoggingTestCase): nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - @helpers.async_test async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): """The `count_ducks` method should return the number of unique staffers who gave a duck.""" test_cases = ( @@ -280,7 +278,6 @@ class DuckPondTests(base.LoggingTestCase): with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): self.assertEqual(expected_count, actual_count) - @helpers.async_test async def test_relay_message_correctly_relays_content_and_attachments(self): """The `relay_message` method should correctly relay message content and attachments.""" send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" @@ -296,8 +293,8 @@ class DuckPondTests(base.LoggingTestCase): ) for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: + with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: with self.subTest(clean_content=message.clean_content, attachments=message.attachments): await self.cog.relay_message(message) @@ -306,8 +303,7 @@ class DuckPondTests(base.LoggingTestCase): message.add_reaction.assert_called_once_with(self.checkmark_emoji) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) - @helpers.async_test + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -316,18 +312,17 @@ class DuckPondTests(base.LoggingTestCase): self.cog.webhook = helpers.MockAsyncWebhook() log = logging.getLogger("bot.cogs.duck_pond") - for side_effect in side_effects: + for side_effect in side_effects: # pragma: no cover 
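# Editor's note -- not part of the diff. The duck_pond changes above and below follow
# the project's move to Python 3.8: the custom `helpers.async_test` decorator and
# `helpers.AsyncMock` are dropped in favour of `unittest.IsolatedAsyncioTestCase` and
# `unittest.mock.AsyncMock` from the standard library. A minimal sketch of the new
# pattern (the class and method names here are illustrative, not from the repository):
import unittest
from unittest.mock import AsyncMock

class ExampleAsyncTest(unittest.IsolatedAsyncioTestCase):
    async def test_awaits_dependency(self):
        dependency = AsyncMock(return_value=42)   # awaitable mock, stdlib since 3.8
        self.assertEqual(await dependency(), 42)  # `async def` tests run natively
        dependency.assert_awaited_once()          # await-specific assertions also built in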
send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: + with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) as send_webhook: with self.subTest(side_effect=type(side_effect).__name__): with self.assertNotLogs(logger=log, level=logging.ERROR): await self.cog.relay_message(message) self.assertEqual(send_webhook.call_count, 2) - @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) - @helpers.async_test + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): """The `relay_message` method should handle irretrievable attachments.""" message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) @@ -360,7 +355,6 @@ class DuckPondTests(base.LoggingTestCase): payload.emoji.name = emoji_name return payload - @helpers.async_test async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" test_values = ( @@ -434,7 +428,6 @@ class DuckPondTests(base.LoggingTestCase): return channel, message, member, payload - @helpers.async_test async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" channel_id = 1234 @@ -463,7 +456,7 @@ class DuckPondTests(base.LoggingTestCase): channel.fetch_message.reset_mock() @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" channel_id = 31415926535 @@ -485,7 +478,6 @@ class DuckPondTests(base.LoggingTestCase): # Assert that we've made it past `self.is_staff` is_staff.assert_called_once() - @helpers.async_test async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" test_cases = ( @@ -499,8 +491,8 @@ class DuckPondTests(base.LoggingTestCase): payload.emoji = self.duck_pond_emoji for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: count_ducks.return_value = duck_count with self.subTest(duck_count=duck_count, should_relay=should_relay): await self.cog.on_raw_reaction_add(payload) @@ -515,7 +507,6 @@ class DuckPondTests(base.LoggingTestCase): if should_relay: relay_message.assert_called_once_with(message) - @helpers.async_test async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): """The `on_raw_reaction_remove` listener prevents removal 
of the check mark on messages with enough ducks.""" checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) @@ -535,7 +526,7 @@ class DuckPondTests(base.LoggingTestCase): (constants.DuckPond.threshold + 1, True), ) for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: count_ducks.return_value = duck_count with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): await self.cog.on_raw_reaction_remove(payload) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 4496a2ae0..3c26374f5 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase): @classmethod def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator) + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) def setUp(self): """Sets up fresh objects for each test.""" @@ -34,7 +34,7 @@ class InformationCogTests(unittest.TestCase): """Test if the `role_info` command correctly returns the `moderator_role`.""" self.ctx.guild.roles.append(self.moderator_role) - self.cog.roles_info.can_run = helpers.AsyncMock() + self.cog.roles_info.can_run = unittest.mock.AsyncMock() self.cog.roles_info.can_run.return_value = True coroutine = self.cog.roles_info.callback(self.cog, self.ctx) @@ -45,10 +45,9 @@ class InformationCogTests(unittest.TestCase): _, kwargs = self.ctx.send.call_args embed = kwargs.pop('embed') - self.assertEqual(embed.title, "Role information") + self.assertEqual(embed.title, "Role information (Total 1 role)") self.assertEqual(embed.colour, discord.Colour.blurple()) - self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n") - self.assertEqual(embed.footer.text, "Total roles: 1") + self.assertEqual(embed.description, f"\n`{self.moderator_role.id}` - {self.moderator_role.mention}\n") def test_role_info_command(self): """Tests the `role info` command.""" @@ -72,7 +71,7 @@ class InformationCogTests(unittest.TestCase): self.ctx.guild.roles.append([dummy_role, admin_role]) - self.cog.role_info.can_run = helpers.AsyncMock() + self.cog.role_info.can_run = unittest.mock.AsyncMock() self.cog.role_info.can_run.return_value = True coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) @@ -125,10 +124,10 @@ class InformationCogTests(unittest.TestCase): ) ], members=[ - *(helpers.MockMember(status='online') for _ in range(2)), - *(helpers.MockMember(status='idle') for _ in range(1)), - *(helpers.MockMember(status='dnd') for _ in range(4)), - *(helpers.MockMember(status='offline') for _ in range(3)), + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), ], member_count=1_234, icon_url='a-lemon.jpg', @@ -153,9 +152,9 @@ class InformationCogTests(unittest.TestCase): **Counts** Members: {self.ctx.guild.member_count:,} Roles: {len(self.ctx.guild.roles)} - Text: 1 - Voice: 1 - Channel categories: 1 + Category channels: 1 + Text channels: 1 + Voice channels: 1 **Members** {constants.Emojis.status_online} 2 @@ 
-174,7 +173,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): def setUp(self): """Common set-up steps done before for each test.""" self.bot = helpers.MockBot() - self.bot.api_client.get = helpers.AsyncMock() + self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) self.member = helpers.MockMember(id=1234) @@ -345,10 +344,10 @@ class UserEmbedTests(unittest.TestCase): def setUp(self): """Common set-up steps done before for each test.""" self.bot = helpers.MockBot() - self.bot.api_client.get = helpers.AsyncMock() + self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -360,7 +359,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Mr. Hemlock") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -372,7 +371,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -387,8 +386,8 @@ class UserEmbedTests(unittest.TestCase): self.assertIn("&Admins", embed.description) self.assertNotIn("&Everyone", embed.description) - @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock) - @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts): """The embed should contain expanded infractions and nomination info in mod channels.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=50)) @@ -423,7 +422,7 @@ class UserEmbedTests(unittest.TestCase): embed.description ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts): """The embed should contain only basic infraction data outside of mod channels.""" ctx = 
helpers.MockContext(channel=helpers.MockTextChannel(id=100)) @@ -454,7 +453,7 @@ class UserEmbedTests(unittest.TestCase): embed.description ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -467,7 +466,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): """The embed should be created with a blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() @@ -477,7 +476,7 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour.blurple()) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value="")) + @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() @@ -521,7 +520,7 @@ class UserCommandTests(unittest.TestCase): """A regular user should not be able to use this command outside of bot-commands.""" constants.MODERATION_ROLES = [self.moderator_role.id] constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) @@ -529,11 +528,11 @@ class UserCommandTests(unittest.TestCase): with self.assertRaises(InChannelCheckFailure, msg=msg): asyncio.run(self.cog.user_info.callback(self.cog, ctx)) - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -542,11 +541,11 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): """A user should target itself with `!user` when a `user` argument was not provided.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, 
channel=helpers.MockTextChannel(id=50)) @@ -555,11 +554,11 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.author) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): """Staff members should be able to bypass the bot-commands channel restriction.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) @@ -568,7 +567,7 @@ class UserCommandTests(unittest.TestCase): create_embed.assert_called_once_with(ctx, self.moderator) ctx.send.assert_called_once() - @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock) + @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=unittest.mock.AsyncMock) def test_moderators_can_target_another_member(self, create_embed, constants): """A moderator should be able to use `!user` targeting another user.""" constants.MODERATION_ROLES = [self.moderator_role.id] diff --git a/tests/bot/cogs/test_snekbox.py b/tests/bot/cogs/test_snekbox.py new file mode 100644 index 000000000..fd9468829 --- /dev/null +++ b/tests/bot/cogs/test_snekbox.py @@ -0,0 +1,354 @@ +import asyncio +import logging +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch + +from bot.cogs import snekbox +from bot.cogs.snekbox import Snekbox +from bot.constants import URLs +from tests.helpers import MockBot, MockContext, MockMessage, MockReaction, MockUser + + +class SnekboxTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Add mocked bot and cog to the instance.""" + self.bot = MockBot() + self.cog = Snekbox(bot=self.bot) + + async def test_post_eval(self): + """Post the eval code to the URLs.snekbox_eval_api endpoint.""" + resp = MagicMock() + resp.json = AsyncMock(return_value="return") + self.bot.http_session.post().__aenter__.return_value = resp + + self.assertEqual(await self.cog.post_eval("import random"), "return") + self.bot.http_session.post.assert_called_with( + URLs.snekbox_eval_api, + json={"input": "import random"}, + raise_for_status=True + ) + resp.json.assert_awaited_once() + + async def test_upload_output_reject_too_long(self): + """Reject output longer than MAX_PASTE_LEN.""" + result = await self.cog.upload_output("-" * (snekbox.MAX_PASTE_LEN + 1)) + self.assertEqual(result, "too long to upload") + + async def test_upload_output(self): + """Upload the eval output to the URLs.paste_service.format(key="documents") endpoint.""" + key = "MarkDiamond" + resp = MagicMock() + resp.json = AsyncMock(return_value={"key": key}) + self.bot.http_session.post().__aenter__.return_value = resp + + self.assertEqual( + await self.cog.upload_output("My awesome output"), + URLs.paste_service.format(key=key) + ) + self.bot.http_session.post.assert_called_with( + URLs.paste_service.format(key="documents"), + data="My awesome output", + raise_for_status=True + ) + + async def test_upload_output_gracefully_fallback_if_exception_during_request(self): + """Output upload gracefully fallback if the upload fail.""" + resp = MagicMock() + resp.json = AsyncMock(side_effect=Exception) + 
self.bot.http_session.post().__aenter__.return_value = resp + + log = logging.getLogger("bot.cogs.snekbox") + with self.assertLogs(logger=log, level='ERROR'): + await self.cog.upload_output('My awesome output!') + + async def test_upload_output_gracefully_fallback_if_no_key_in_response(self): + """Output upload gracefully fallback if there is no key entry in the response body.""" + self.assertEqual((await self.cog.upload_output('My awesome output!')), None) + + def test_prepare_input(self): + cases = ( + ('print("Hello world!")', 'print("Hello world!")', 'non-formatted'), + ('`print("Hello world!")`', 'print("Hello world!")', 'one line code block'), + ('```\nprint("Hello world!")```', 'print("Hello world!")', 'multiline code block'), + ('```py\nprint("Hello world!")```', 'print("Hello world!")', 'multiline python code block'), + ) + for case, expected, testname in cases: + with self.subTest(msg=f'Extract code from {testname}.'): + self.assertEqual(self.cog.prepare_input(case), expected) + + def test_get_results_message(self): + """Return error and message according to the eval result.""" + cases = ( + ('ERROR', None, ('Your eval job has failed', 'ERROR')), + ('', 128 + snekbox.SIGKILL, ('Your eval job timed out or ran out of memory', '')), + ('', 255, ('Your eval job has failed', 'A fatal NsJail error occurred')) + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_results_message({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + @patch('bot.cogs.snekbox.Signals', side_effect=ValueError) + def test_get_results_message_invalid_signal(self, mock_signals: Mock): + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127', '') + ) + + @patch('bot.cogs.snekbox.Signals') + def test_get_results_message_valid_signal(self, mock_signals: Mock): + mock_signals.return_value.name = 'SIGTEST' + self.assertEqual( + self.cog.get_results_message({'stdout': '', 'returncode': 127}), + ('Your eval job has completed with return code 127 (SIGTEST)', '') + ) + + def test_get_status_emoji(self): + """Return emoji according to the eval result.""" + cases = ( + (' ', -1, ':warning:'), + ('Hello world!', 0, ':white_check_mark:'), + ('Invalid beard size', -1, ':x:') + ) + for stdout, returncode, expected in cases: + with self.subTest(stdout=stdout, returncode=returncode, expected=expected): + actual = self.cog.get_status_emoji({'stdout': stdout, 'returncode': returncode}) + self.assertEqual(actual, expected) + + async def test_format_output(self): + """Test output formatting.""" + self.cog.upload_output = AsyncMock(return_value='https://testificate.com/') + + too_many_lines = ( + '001 | v\n002 | e\n003 | r\n004 | y\n005 | l\n006 | o\n' + '007 | n\n008 | g\n009 | b\n010 | e\n011 | a\n... (truncated - too many lines)' + ) + too_long_too_many_lines = ( + "\n".join( + f"{i:03d} | {line}" for i, line in enumerate(['verylongbeard' * 10] * 15, 1) + )[:1000] + "\n... 
(truncated - too long, too many lines)" + ) + + cases = ( + ('', ('[No output]', None), 'No output'), + ('My awesome output', ('My awesome output', None), 'One line output'), + ('<@', ("<@\u200B", None), r'Convert <@ to <@\u200B'), + ('<!@', ("<!@\u200B", None), r'Convert <!@ to <!@\u200B'), + ( + '\u202E\u202E\u202E', + ('Code block escape attempt detected; will not output result', None), + 'Detect RIGHT-TO-LEFT OVERRIDE' + ), + ( + '\u200B\u200B\u200B', + ('Code block escape attempt detected; will not output result', None), + 'Detect ZERO WIDTH SPACE' + ), + ('long\nbeard', ('001 | long\n002 | beard', None), 'Two line output'), + ( + 'v\ne\nr\ny\nl\no\nn\ng\nb\ne\na\nr\nd', + (too_many_lines, 'https://testificate.com/'), + '12 lines output' + ), + ( + 'verylongbeard' * 100, + ('verylongbeard' * 76 + 'verylongbear\n... (truncated - too long)', 'https://testificate.com/'), + '1300 characters output' + ), + ( + ('verylongbeard' * 10 + '\n') * 15, + (too_long_too_many_lines, 'https://testificate.com/'), + '15 lines, 1965 characters output' + ), + ) + for case, expected, testname in cases: + with self.subTest(msg=testname, case=case, expected=expected): + self.assertEqual(await self.cog.format_output(case), expected) + + async def test_eval_command_evaluate_once(self): + """Test the eval command procedure.""" + ctx = MockContext() + response = MockMessage() + self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.send_eval = AsyncMock(return_value=response) + self.cog.continue_eval = AsyncMock(return_value=None) + + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + self.cog.prepare_input.assert_called_once_with('MyAwesomeCode') + self.cog.send_eval.assert_called_once_with(ctx, 'MyAwesomeFormattedCode') + self.cog.continue_eval.assert_called_once_with(ctx, response) + + async def test_eval_command_evaluate_twice(self): + """Test the eval and re-eval command procedure.""" + ctx = MockContext() + response = MockMessage() + self.cog.prepare_input = MagicMock(return_value='MyAwesomeFormattedCode') + self.cog.send_eval = AsyncMock(return_value=response) + self.cog.continue_eval = AsyncMock() + self.cog.continue_eval.side_effect = ('MyAwesomeCode-2', None) + + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + self.cog.prepare_input.has_calls(call('MyAwesomeCode'), call('MyAwesomeCode-2')) + self.cog.send_eval.assert_called_with(ctx, 'MyAwesomeFormattedCode') + self.cog.continue_eval.assert_called_with(ctx, response) + + async def test_eval_command_reject_two_eval_at_the_same_time(self): + """Test if the eval command rejects an eval if the author already have a running eval.""" + ctx = MockContext() + ctx.author.id = 42 + ctx.author.mention = '@LemonLemonishBeard#0042' + ctx.send = AsyncMock() + self.cog.jobs = (42,) + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='MyAwesomeCode') + ctx.send.assert_called_once_with( + "@LemonLemonishBeard#0042 You've already got a job running - please wait for it to finish!" 
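# Editor's note -- not part of the diff. The eval-command tests above assert a
# prepare/send/continue loop: the command formats the input once, runs one eval,
# then asks `continue_eval` for replacement code until it returns None. A hedged
# sketch of that control flow (the real command lives in `bot/cogs/snekbox.py`;
# this free-standing coroutine is illustrative only):
async def eval_loop_sketch(cog, ctx, code: str) -> None:
    while code:                                        # falsy code means stop re-evaluating
        code = cog.prepare_input(code)                 # strip code blocks / backticks
        response = await cog.send_eval(ctx, code)      # run the job and report the result
        code = await cog.continue_eval(ctx, response)  # new code to re-run, or None to finish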
+ ) + + async def test_eval_command_call_help(self): + """Test if the eval command call the help command if no code is provided.""" + ctx = MockContext() + ctx.invoke = AsyncMock() + await self.cog.eval_command.callback(self.cog, ctx=ctx, code='') + ctx.invoke.assert_called_once_with(self.bot.get_command("help"), "eval") + + async def test_send_eval(self): + """Test the send_eval function.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + + self.cog.post_eval = AsyncMock(return_value={'stdout': '', 'returncode': 0}) + self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) + self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + self.cog.format_output = AsyncMock(return_value=('[No output]', None)) + + await self.cog.send_eval(ctx, 'MyAwesomeCode') + ctx.send.assert_called_once_with( + '@LemonLemonishBeard#0042 :yay!: Return code 0.\n\n```py\n[No output]\n```' + ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.get_status_emoji.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': '', 'returncode': 0}) + self.cog.format_output.assert_called_once_with('') + + async def test_send_eval_with_paste_link(self): + """Test the send_eval function with a too long output that generate a paste link.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + + self.cog.post_eval = AsyncMock(return_value={'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message = MagicMock(return_value=('Return code 0', '')) + self.cog.get_status_emoji = MagicMock(return_value=':yay!:') + self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) + + await self.cog.send_eval(ctx, 'MyAwesomeCode') + ctx.send.assert_called_once_with( + '@LemonLemonishBeard#0042 :yay!: Return code 0.' 
+ '\n\n```py\nWay too long beard\n```\nFull output: lookatmybeard.com' + ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.get_status_emoji.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'Way too long beard', 'returncode': 0}) + self.cog.format_output.assert_called_once_with('Way too long beard') + + async def test_send_eval_with_non_zero_eval(self): + """Test the send_eval function with a code returning a non-zero code.""" + ctx = MockContext() + ctx.message = MockMessage() + ctx.send = AsyncMock() + ctx.author.mention = '@LemonLemonishBeard#0042' + self.cog.post_eval = AsyncMock(return_value={'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message = MagicMock(return_value=('Return code 127', 'Beard got stuck in the eval')) + self.cog.get_status_emoji = MagicMock(return_value=':nope!:') + self.cog.format_output = AsyncMock() # This function isn't called + + await self.cog.send_eval(ctx, 'MyAwesomeCode') + ctx.send.assert_called_once_with( + '@LemonLemonishBeard#0042 :nope!: Return code 127.\n\n```py\nBeard got stuck in the eval\n```' + ) + self.cog.post_eval.assert_called_once_with('MyAwesomeCode') + self.cog.get_status_emoji.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.get_results_message.assert_called_once_with({'stdout': 'ERROR', 'returncode': 127}) + self.cog.format_output.assert_not_called() + + @patch("bot.cogs.snekbox.partial") + async def test_continue_eval_does_continue(self, partial_mock): + """Test that the continue_eval function does continue if required conditions are met.""" + ctx = MockContext(message=MockMessage(add_reaction=AsyncMock(), clear_reactions=AsyncMock())) + response = MockMessage(delete=AsyncMock()) + new_msg = MockMessage(content='!e NewCode') + self.bot.wait_for.side_effect = ((None, new_msg), None) + + actual = await self.cog.continue_eval(ctx, response) + self.assertEqual(actual, 'NewCode') + self.bot.wait_for.assert_has_awaits( + ( + call('message_edit', check=partial_mock(snekbox.predicate_eval_message_edit, ctx), timeout=10), + call('reaction_add', check=partial_mock(snekbox.predicate_eval_emoji_reaction, ctx), timeout=10) + ) + ) + ctx.message.add_reaction.assert_called_once_with(snekbox.REEVAL_EMOJI) + ctx.message.clear_reactions.assert_called_once() + response.delete.assert_called_once() + + async def test_continue_eval_does_not_continue(self): + ctx = MockContext(message=MockMessage(clear_reactions=AsyncMock())) + self.bot.wait_for.side_effect = asyncio.TimeoutError + + actual = await self.cog.continue_eval(ctx, MockMessage()) + self.assertEqual(actual, None) + ctx.message.clear_reactions.assert_called_once() + + def test_predicate_eval_message_edit(self): + """Test the predicate_eval_message_edit function.""" + msg0 = MockMessage(id=1, content='abc') + msg1 = MockMessage(id=2, content='abcdef') + msg2 = MockMessage(id=1, content='abcdef') + + cases = ( + (msg0, msg0, False, 'same ID, same content'), + (msg0, msg1, False, 'different ID, different content'), + (msg0, msg2, True, 'same ID, different content') + ) + for ctx_msg, new_msg, expected, testname in cases: + with self.subTest(msg=f'Messages with {testname} return {expected}'): + ctx = MockContext(message=ctx_msg) + actual = snekbox.predicate_eval_message_edit(ctx, ctx_msg, new_msg) + self.assertEqual(actual, expected) + + def test_predicate_eval_emoji_reaction(self): + """Test the predicate_eval_emoji_reaction function.""" + 
valid_reaction = MockReaction(message=MockMessage(id=1)) + valid_reaction.__str__.return_value = snekbox.REEVAL_EMOJI + valid_ctx = MockContext(message=MockMessage(id=1), author=MockUser(id=2)) + valid_user = MockUser(id=2) + + invalid_reaction_id = MockReaction(message=MockMessage(id=42)) + invalid_reaction_id.__str__.return_value = snekbox.REEVAL_EMOJI + invalid_user_id = MockUser(id=42) + invalid_reaction_str = MockReaction(message=MockMessage(id=1)) + invalid_reaction_str.__str__.return_value = ':longbeard:' + + cases = ( + (invalid_reaction_id, valid_user, False, 'invalid reaction ID'), + (valid_reaction, invalid_user_id, False, 'invalid user ID'), + (invalid_reaction_str, valid_user, False, 'invalid reaction __str__'), + (valid_reaction, valid_user, True, 'matching attributes') + ) + for reaction, user, expected, testname in cases: + with self.subTest(msg=f'Test with {testname} and expected return {expected}'): + actual = snekbox.predicate_eval_emoji_reaction(valid_ctx, reaction, user) + self.assertEqual(actual, expected) + + +class SnekboxSetupTests(unittest.TestCase): + """Tests setup of the `Snekbox` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = MockBot() + snekbox.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index a54b839d7..33d1ec170 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -1,7 +1,7 @@ import asyncio import logging import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock from discord import Colour @@ -11,7 +11,7 @@ from bot.cogs.token_remover import ( setup as setup_cog, ) from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock, MockBot, MockMessage +from tests.helpers import MockBot, MockMessage class TokenRemoverTests(unittest.TestCase): diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index e69de29bb..0d570f5a3 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -0,0 +1,76 @@ +import unittest +from abc import ABCMeta, abstractmethod +from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple + +from tests.helpers import MockMessage + + +class DisallowedCase(NamedTuple): + """Encapsulation for test cases expected to fail.""" + recent_messages: List[MockMessage] + culprits: Iterable[str] + n_violations: int + + +class RuleTest(unittest.IsolatedAsyncioTestCase, metaclass=ABCMeta): + """ + Abstract class for antispam rule test cases. + + Tests for specific rules should inherit from `RuleTest` and implement + `relevant_messages` and `get_report`. Each instance should also set the + `apply` and `config` attributes as necessary. + + The execution of test cases can then be delegated to the `run_allowed` + and `run_disallowed` methods. 
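# Editor's note -- not part of the diff. `run_allowed` and `run_disallowed`, defined
# just below, encode the contract every antispam rule's `apply` coroutine must satisfy:
# return None when the recent messages are fine, otherwise a
# (report, culprits, relevant_messages) tuple. A hedged sketch of a trivial rule written
# against that contract (the rule logic and threshold here are invented for illustration;
# the real rules live in `bot/rules/`):
from typing import Optional, Tuple

async def apply_sketch(last_message, recent_messages, config) -> Optional[Tuple]:
    """Flag an author who sent more than config["max"] messages in the window."""
    relevant = tuple(msg for msg in recent_messages if msg.author == last_message.author)
    if len(relevant) > config["max"]:
        report = f"sent {len(relevant)} messages in {config['interval']}s"
        return report, (last_message.author,), relevant
    return None  # allowed: run_allowed's assertIsNone expects exactly this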
+ """ + + apply: Callable # The tested rule's apply function + config: Dict[str, int] + + async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: + """Run all `cases` against `self.apply` expecting them to pass.""" + for recent_messages in cases: + last_message = recent_messages[0] + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + config=self.config, + ): + self.assertIsNone( + await self.apply(last_message, recent_messages, self.config) + ) + + async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: + """Run all `cases` against `self.apply` expecting them to fail.""" + for case in cases: + recent_messages, culprits, n_violations = case + last_message = recent_messages[0] + relevant_messages = self.relevant_messages(case) + desired_output = ( + self.get_report(case), + culprits, + relevant_messages, + ) + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + relevant_messages=relevant_messages, + n_violations=n_violations, + config=self.config, + ): + self.assertTupleEqual( + await self.apply(last_message, recent_messages, self.config), + desired_output, + ) + + @abstractmethod + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + """Give expected relevant messages for `case`.""" + raise NotImplementedError # pragma: no cover + + @abstractmethod + def get_report(self, case: DisallowedCase) -> str: + """Give expected error report for `case`.""" + raise NotImplementedError # pragma: no cover diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index d7187f315..d7e779221 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,98 +1,69 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import attachments -from tests.helpers import MockMessage, async_test +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_attachments: int - - -def msg(author: str, total_attachments: int) -> MockMessage: +def make_msg(author: str, total_attachments: int) -> MockMessage: """Builds a message with `total_attachments` attachments.""" return MockMessage(author=author, attachments=list(range(total_attachments))) -class AttachmentRuleTests(unittest.TestCase): +class AttachmentRuleTests(RuleTest): """Tests applying the `attachments` antispam rule.""" def setUp(self): - self.config = {"max": 5} + self.apply = attachments.apply + self.config = {"max": 5, "interval": 10} - @async_test async def test_allows_messages_without_too_many_attachments(self): """Messages without too many attachments are allowed as-is.""" cases = ( - [msg("bob", 0), msg("bob", 0), msg("bob", 0)], - [msg("bob", 2), msg("bob", 2)], - [msg("bob", 2), msg("alice", 2), msg("bob", 2)], + [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], + [make_msg("bob", 2), make_msg("bob", 2)], + [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await attachments.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) - @async_test async def test_disallows_messages_with_too_many_attachments(self): """Messages with too 
many attachments trigger the rule.""" cases = ( - Case( - [msg("bob", 4), msg("bob", 0), msg("bob", 6)], + DisallowedCase( + [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], ("bob",), - 10 + 10, ), - Case( - [msg("bob", 4), msg("alice", 6), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], ("bob",), - 6 + 6, ), - Case( - [msg("alice", 6)], + DisallowedCase( + [make_msg("alice", 6)], ("alice",), - 6 + 6, ), - ( - [msg("alice", 1) for _ in range(6)], + DisallowedCase( + [make_msg("alice", 1) for _ in range(6)], ("alice",), - 6 + 6, ), ) - for recent_messages, culprit, total_attachments in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if ( - msg.author == last_message.author - and len(msg.attachments) > 0 - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if ( + msg.author == last_message.author + and len(msg.attachments) > 0 ) + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - total_attachments=total_attachments, - config=self.config - ): - desired_output = ( - f"sent {total_attachments} attachments in {self.config['max']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await attachments.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py new file mode 100644 index 000000000..03682966b --- /dev/null +++ b/tests/bot/rules/test_burst.py @@ -0,0 +1,54 @@ +from typing import Iterable + +from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + + +def make_msg(author: str) -> MockMessage: + """ + Init a MockMessage instance with author set to `author`. + + This serves as a shorthand / alias to keep the test cases visually clean. 
+ """ + return MockMessage(author=author) + + +class BurstRuleTests(RuleTest): + """Tests the `burst` antispam rule.""" + + def setUp(self): + self.apply = burst.apply + self.config = {"max": 2, "interval": 10} + + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("bob"), make_msg("bob")], + [make_msg("bob"), make_msg("alice"), make_msg("bob")], + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_beyond_limit(self): + """Cases where the amount of messages exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("bob")], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], + ("bob",), + 3, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py new file mode 100644 index 000000000..3275143d5 --- /dev/null +++ b/tests/bot/rules/test_burst_shared.py @@ -0,0 +1,57 @@ +from typing import Iterable + +from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + + +def make_msg(author: str) -> MockMessage: + """ + Init a MockMessage instance with the passed arg. + + This serves as a shorthand / alias to keep the test cases visually clean. + """ + return MockMessage(author=author) + + +class BurstSharedRuleTests(RuleTest): + """Tests the `burst_shared` antispam rule.""" + + def setUp(self): + self.apply = burst_shared.apply + self.config = {"max": 2, "interval": 10} + + async def test_allows_messages_within_limit(self): + """ + Cases that do not violate the rule. + + There really isn't more to test here than a single case. 
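# Editor's note -- not part of the diff. The burst_shared cases below differ from the
# per-author `burst` rule above in that every recent message counts toward the limit
# regardless of author, so `relevant_messages` is the unfiltered list and `culprits`
# can name several users at once. A hedged sketch of that counting difference
# (illustrative only, not the rules' actual implementation):
def count_burst(recent_messages, author) -> int:
    return sum(1 for msg in recent_messages if msg.author == author)  # per-author limit

def count_burst_shared(recent_messages) -> int:
    return len(recent_messages)  # shared limit: all authors' messages count together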
+ """ + cases = ( + [make_msg("spongebob"), make_msg("patrick")], + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_beyond_limit(self): + """Cases where the amount of messages exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("bob")], + {"bob"}, + 3, + ), + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], + {"bob", "alice"}, + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return case.recent_messages + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py new file mode 100644 index 000000000..f1e3c76a7 --- /dev/null +++ b/tests/bot/rules/test_chars.py @@ -0,0 +1,64 @@ +from typing import Iterable + +from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + + +def make_msg(author: str, n_chars: int) -> MockMessage: + """Build a message with arbitrary content of `n_chars` length.""" + return MockMessage(author=author, content="A" * n_chars) + + +class CharsRuleTests(RuleTest): + """Tests the `chars` antispam rule.""" + + def setUp(self): + self.apply = chars.apply + self.config = { + "max": 20, # Max allowed sum of chars per user + "interval": 10, + } + + async def test_allows_messages_within_limit(self): + """Cases with a total amount of chars within limit.""" + cases = ( + [make_msg("bob", 0)], + [make_msg("bob", 20)], + [make_msg("bob", 15), make_msg("alice", 15)], + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_beyond_limit(self): + """Cases where the total amount of chars exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob", 21)], + ("bob",), + 21, + ), + DisallowedCase( + [make_msg("bob", 15), make_msg("bob", 15)], + ("bob",), + 30, + ), + DisallowedCase( + [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], + ("alice",), + 30, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py new file mode 100644 index 000000000..9a72723e2 --- /dev/null +++ b/tests/bot/rules/test_discord_emojis.py @@ -0,0 +1,52 @@ +from typing import Iterable + +from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + +discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> + + +def make_msg(author: str, n_emojis: int) -> MockMessage: + """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" + return MockMessage(author=author, content=discord_emoji * n_emojis) + + +class DiscordEmojisRuleTests(RuleTest): + """Tests for the `discord_emojis` antispam rule.""" + + def setUp(self): + self.apply = discord_emojis.apply + self.config = {"max": 2, "interval": 10} + + async def test_allows_messages_within_limit(self): + """Cases with a total amount of discord 
emojis within limit.""" + cases = ( + [make_msg("bob", 2)], + [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_beyond_limit(self): + """Cases with more than the allowed amount of discord emojis.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py new file mode 100644 index 000000000..9bd886a77 --- /dev/null +++ b/tests/bot/rules/test_duplicates.py @@ -0,0 +1,64 @@ +from typing import Iterable + +from bot.rules import duplicates +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + + +def make_msg(author: str, content: str) -> MockMessage: + """Give a MockMessage instance with `author` and `content` attrs.""" + return MockMessage(author=author, content=content) + + +class DuplicatesRuleTests(RuleTest): + """Tests the `duplicates` antispam rule.""" + + def setUp(self): + self.apply = duplicates.apply + self.config = {"max": 2, "interval": 10} + + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("alice", "A"), make_msg("alice", "A")], + [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate + [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_beyond_limit(self): + """Cases with too many duplicate messages from the same author.""" + cases = ( + DisallowedCase( + [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], + ("alice",), + 3, + ), + DisallowedCase( + [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], + ("bob",), + 3, # 4 duplicate messages, but only 3 from bob + ), + DisallowedCase( + [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], + ("bob",), + 3, # 4 message from bob, but only 3 duplicates + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if ( + msg.author == last_message.author + and msg.content == last_message.content + ) + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 02a5d5501..b091bd9d7 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,97 +1,67 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import links -from tests.helpers import MockMessage, async_test +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_links: 
int - - -def msg(author: str, total_links: int) -> MockMessage: +def make_msg(author: str, total_links: int) -> MockMessage: """Makes a message with `total_links` links.""" content = " ".join(["https://pydis.com"] * total_links) return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest): """Tests applying the `links` rule.""" def setUp(self): + self.apply = links.apply self.config = { "max": 2, "interval": 10 } - @async_test async def test_links_within_limit(self): """Messages with an allowed amount of links.""" cases = ( - [msg("bob", 0)], - [msg("bob", 2)], - [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 - [msg("bob", 1), msg("bob", 1)], - [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await links.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) - @async_test async def test_links_exceeding_limit(self): """Messages with a a higher than allowed amount of links.""" cases = ( - Case( - [msg("bob", 1), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 1), make_msg("bob", 2)], ("bob",), 3 ), - Case( - [msg("alice", 1), msg("alice", 1), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], ("alice",), 3 ), - Case( - [msg("alice", 2), msg("bob", 3), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], ("alice",), 3 ) ) - for recent_messages, culprit, total_links in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if msg.author == last_message.author - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - culprit=culprit, - total_links=total_links, - config=self.config - ): - desired_output = ( - f"sent {total_links} links in {self.config['interval']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await links.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ad49ead32..6444532f2 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,95 +1,65 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import mentions -from tests.helpers import MockMessage, async_test +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_mentions: int - - -def msg(author: 
str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_mentions: int) -> MockMessage: """Makes a message with `total_mentions` mentions.""" return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest): """Tests applying the `mentions` antispam rule.""" def setUp(self): + self.apply = mentions.apply self.config = { "max": 2, - "interval": 10 + "interval": 10, } - @async_test async def test_mentions_within_limit(self): """Messages with an allowed amount of mentions.""" cases = ( - [msg("bob", 0)], - [msg("bob", 2)], - [msg("bob", 1), msg("bob", 1)], - [msg("bob", 1), msg("alice", 2)] + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 1), make_msg("alice", 2)], ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await mentions.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) - @async_test async def test_mentions_exceeding_limit(self): """Messages with a higher than allowed amount of mentions.""" cases = ( - Case( - [msg("bob", 3)], + DisallowedCase( + [make_msg("bob", 3)], ("bob",), - 3 + 3, ), - Case( - [msg("alice", 2), msg("alice", 0), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], ("alice",), - 3 + 3, ), - Case( - [msg("bob", 2), msg("alice", 3), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], ("bob",), - 4 + 4, ) ) - for recent_messages, culprit, total_mentions in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if msg.author == last_message.author - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - culprit=culprit, - total_mentions=total_mentions, - cofig=self.config - ): - desired_output = ( - f"sent {total_mentions} mentions in {self.config['interval']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await mentions.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py new file mode 100644 index 000000000..e35377773 --- /dev/null +++ b/tests/bot/rules/test_newlines.py @@ -0,0 +1,102 @@ +from typing import Iterable, List + +from bot.rules import newlines +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + + +def make_msg(author: str, newline_groups: List[int]) -> MockMessage: + """Init a MockMessage instance with `author` and content configured by `newline_groups". + + Configure content by passing a list of ints, where each int `n` will generate + a separate group of `n` newlines. 
+ + Example: + newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" + """ + content = " ".join("\n" * n for n in newline_groups) + return MockMessage(author=author, content=content) + + +class TotalNewlinesRuleTests(RuleTest): + """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" + + def setUp(self): + self.apply = newlines.apply + self.config = { + "max": 5, # Max sum of newlines in relevant messages + "max_consecutive": 3, # Max newlines in one group, in one message + "interval": 10, + } + + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("alice", [])], # Single message with no newlines + [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages + [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author + [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_total(self): + """Cases which violate the rule by having too many newlines in total.""" + cases = ( + DisallowedCase( # Alice sends a total of 6 newlines (disallowed) + [make_msg("alice", [2, 2]), make_msg("alice", [2])], + ("alice",), + 6, + ), + DisallowedCase( # Here we test that only alice's newlines count in the sum + [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], + ("alice",), + 7, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_author = case.recent_messages[0].author + return tuple(msg for msg in case.recent_messages if msg.author == last_author) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} newlines in {self.config['interval']}s" + + +class GroupNewlinesRuleTests(RuleTest): + """ + Tests the `newlines` antispam rule against max consecutive newline violations. + + As these violations yield a different error report, they require a different + `get_report` implementation. 
+ """ + + def setUp(self): + self.apply = newlines.apply + self.config = {"max": 5, "max_consecutive": 3, "interval": 10} + + async def test_disallows_messages_consecutive(self): + """Cases which violate the rule due to having too many consecutive newlines.""" + cases = ( + DisallowedCase( # Bob sends a group of newlines too large + [make_msg("bob", [4])], + ("bob",), + 4, + ), + DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed) + [make_msg("alice", [1]), make_msg("alice", [4])], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_author = case.recent_messages[0].author + return tuple(msg for msg in case.recent_messages if msg.author == last_author) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py new file mode 100644 index 000000000..26c05d527 --- /dev/null +++ b/tests/bot/rules/test_role_mentions.py @@ -0,0 +1,55 @@ +from typing import Iterable + +from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage + + +def make_msg(author: str, n_mentions: int) -> MockMessage: + """Build a MockMessage instance with `n_mentions` role mentions.""" + return MockMessage(author=author, role_mentions=[None] * n_mentions) + + +class RoleMentionsRuleTests(RuleTest): + """Tests for the `role_mentions` antispam rule.""" + + def setUp(self): + self.apply = role_mentions.apply + self.config = {"max": 2, "interval": 10} + + async def test_allows_messages_within_limit(self): + """Cases with a total amount of role mentions within limit.""" + cases = ( + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], + ) + + await self.run_allowed(cases) + + async def test_disallows_messages_beyond_limit(self): + """Cases with more than the allowed amount of role mentions.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 5a88adc5c..99e942813 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -1,13 +1,10 @@ -import logging import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from bot import api -from tests.base import LoggingTestCase -from tests.helpers import async_test -class APIClientTests(unittest.TestCase): +class APIClientTests(unittest.IsolatedAsyncioTestCase): """Tests for the bot's API client.""" @classmethod @@ -20,7 +17,6 @@ class APIClientTests(unittest.TestCase): """The event loop should not be running by default.""" self.assertFalse(api.loop_is_running()) - @async_test async def test_loop_is_running_in_async_context(self): """The event loop should be running in an async context.""" self.assertTrue(api.loop_is_running()) @@ -34,7 +30,7 @@ class 
APIClientTests(unittest.TestCase): self.assertEqual(error.response_text, "") self.assertIs(error.response, self.error_api_response) - def test_responde_code_error_string_representation_default_initialization(self): + def test_response_code_error_string_representation_default_initialization(self): """Test the string representation of `ResponseCodeError` initialized without text or json.""" error = api.ResponseCodeError(response=self.error_api_response) self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") @@ -76,61 +72,3 @@ class APIClientTests(unittest.TestCase): response_text=text_data ) self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") - - -class LoggingHandlerTests(LoggingTestCase): - """Tests the bot's API Log Handler.""" - - @classmethod - def setUpClass(cls): - cls.debug_log_record = logging.LogRecord( - name='my.logger', level=logging.DEBUG, - pathname='my/logger.py', lineno=666, - msg="Lemon wins", args=(), - exc_info=None - ) - - cls.trace_log_record = logging.LogRecord( - name='my.logger', level=logging.TRACE, - pathname='my/logger.py', lineno=666, - msg="This will not be logged", args=(), - exc_info=None - ) - - def setUp(self): - self.log_handler = api.APILoggingHandler(None) - - def test_emit_appends_to_queue_with_stopped_event_loop(self): - """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" - with patch("bot.api.APILoggingHandler.ship_off") as ship_off: - # Patch `ship_off` to ease testing against the return value of this coroutine. - ship_off.return_value = 42 - self.log_handler.emit(self.debug_log_record) - - self.assertListEqual(self.log_handler.queue, [42]) - - def test_emit_ignores_less_than_debug(self): - """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" - self.log_handler.emit(self.trace_log_record) - self.assertListEqual(self.log_handler.queue, []) - - def test_schedule_queued_tasks_for_empty_queue(self): - """`APILoggingHandler` should not schedule anything when the queue is empty.""" - with self.assertNotLogs(level=logging.DEBUG): - self.log_handler.schedule_queued_tasks() - - def test_schedule_queued_tasks_for_nonempty_queue(self): - """`APILoggingHandler` should schedule logs when the queue is not empty.""" - log = logging.getLogger("bot.api") - - with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: - self.log_handler.queue = [555] - self.log_handler.schedule_queued_tasks() - self.assertListEqual(self.log_handler.queue, []) - create_task.assert_called_once_with(555) - - [record] = logs.records - self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") - self.assertEqual(record.levelno, logging.DEBUG) - self.assertEqual(record.name, 'bot.api') - self.assertIn('via_handler', record.__dict__) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index b2b78d9dd..1e5ca62ae 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -68,7 +68,7 @@ class ConverterTests(unittest.TestCase): ('👋', "Don't be ridiculous, you can't use that character!"), ('', "Tag names should not be empty, or filled with whitespace."), (' ', "Tag names should not be empty, or filled with whitespace."), - ('42', "Tag names can't be numbers."), + ('42', "Tag names must contain at least one letter."), ('x' * 128, "Are you insane? 
That's way too long!"), ) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py deleted file mode 100644 index 58ae2a81a..000000000 --- a/tests/bot/test_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -import unittest - -from bot import utils - - -class CaseInsensitiveDictTests(unittest.TestCase): - """Tests for the `CaseInsensitiveDict` container.""" - - def test_case_insensitive_key_access(self): - """Tests case insensitive key access and storage.""" - instance = utils.CaseInsensitiveDict() - - key = 'LEMON' - value = 'trees' - - instance[key] = value - self.assertIn(key, instance) - self.assertEqual(instance.get(key), value) - self.assertEqual(instance.get(key.casefold()), value) - self.assertEqual(instance.pop(key.casefold()), value) - self.assertNotIn(key, instance) - self.assertNotIn(key.casefold(), instance) - - instance.setdefault(key, value) - del instance[key] - self.assertNotIn(key, instance) - - def test_initialization_from_kwargs(self): - """Tests creating the dictionary from keyword arguments.""" - instance = utils.CaseInsensitiveDict({'FOO': 'bar'}) - self.assertEqual(instance['foo'], 'bar') - - def test_update_from_other_mapping(self): - """Tests updating the dictionary from another mapping.""" - instance = utils.CaseInsensitiveDict() - instance.update({'FOO': 'bar'}) - self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): - """Tests the `chunk` method.""" - - def test_empty_chunking(self): - """Tests chunking on an empty iterable.""" - generator = utils.chunks(iterable=[], size=5) - self.assertEqual(list(generator), []) - - def test_list_chunking(self): - """Tests chunking a non-empty list.""" - iterable = [1, 2, 3, 4, 5] - generator = utils.chunks(iterable=iterable, size=2) - self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py index 69f35f2f5..694d3a40f 100644 --- a/tests/bot/utils/test_time.py +++ b/tests/bot/utils/test_time.py @@ -1,12 +1,11 @@ import asyncio import unittest from datetime import datetime, timezone -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from dateutil.relativedelta import relativedelta from bot.utils import time -from tests.helpers import AsyncMock class TimeTests(unittest.TestCase): @@ -44,7 +43,7 @@ class TimeTests(unittest.TestCase): for max_units in test_cases: with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) - self.assertEqual(str(error), 'max_units must be positive') + self.assertEqual(str(error.exception), 'max_units must be positive') def test_parse_rfc1123(self): """Testing parse_rfc1123.""" diff --git a/tests/helpers.py b/tests/helpers.py index 5df796c23..8e13f0f28 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,17 +1,15 @@ from __future__ import annotations -import asyncio import collections -import functools -import inspect import itertools import logging import unittest.mock -from typing import Any, Iterable, Optional +from typing import Iterable, Optional import discord from discord.ext.commands import Context +from bot.api import APIClient from bot.bot import Bot @@ -25,21 +23,6 @@ for logger in logging.Logger.manager.loggerDict.values(): logger.setLevel(logging.CRITICAL) -def async_test(wrapped): - """ - Run a test case via asyncio. - Example: - >>> @async_test - ... async def lemon_wins(): - ... 
assert True
-    """
-
-    @functools.wraps(wrapped)
-    def wrapper(*args, **kwargs):
-        return asyncio.run(wrapped(*args, **kwargs))
-    return wrapper
-
-
 class HashableMixin(discord.mixins.EqualityComparable):
     """
     Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.
@@ -68,24 +51,31 @@ class CustomMockMixin:
     """
     Provides common functionality for our custom Mock types.
 
-    The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
-    function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
-    of making sure child mocks are instantiated with the correct class. By default, the mock of the
-    children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
-    `child_mock_type` on the custom mock inheriting from this mixin.
+    The `_get_child_mock` method automatically returns an AsyncMock for coroutine methods of the mock
+    object. As discord.py also uses synchronous methods that nonetheless return coroutine objects, the
+    class attribute `additional_spec_asyncs` can be overwritten with an iterable containing additional
+    attribute names that should also be mocked with an AsyncMock instead of a regular MagicMock/Mock. The
+    class attribute `spec_set` can be overwritten with the object that should be used as the specification
+    for the mock.
+
+    Mock/MagicMock subclasses that use this mixin only need to define an `__init__` method if they need to
+    implement custom behavior.
     """
 
     child_mock_type = unittest.mock.MagicMock
     discord_id = itertools.count(0)
+    spec_set = None
+    additional_spec_asyncs = None
 
-    def __init__(self, spec_set: Any = None, **kwargs):
+    def __init__(self, **kwargs):
         name = kwargs.pop('name', None)  # `name` has special meaning for Mock classes, so we need to set it manually.
-        super().__init__(spec_set=spec_set, **kwargs)
+        super().__init__(spec_set=self.spec_set, **kwargs)
+
+        if self.additional_spec_asyncs:
+            self._spec_asyncs.extend(self.additional_spec_asyncs)
 
         if name:
             self.name = name
-        if spec_set:
-            self._extract_coroutine_methods_from_spec_instance(spec_set)
 
     def _get_child_mock(self, **kw):
         """
@@ -99,7 +89,16 @@ class CustomMockMixin:
 
         This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
         """
-        klass = self.child_mock_type
+        _new_name = kw.get("_new_name")
+        if _new_name in self.__dict__['_spec_asyncs']:
+            return unittest.mock.AsyncMock(**kw)
+
+        _type = type(self)
+        if issubclass(_type, unittest.mock.MagicMock) and _new_name in unittest.mock._async_method_magics:
+            # Any asynchronous magic becomes an AsyncMock
+            klass = unittest.mock.AsyncMock
+        else:
+            klass = self.child_mock_type
 
         if self._mock_sealed:
             attribute = "." + kw["name"] if "name" in kw else "()"
@@ -108,95 +107,6 @@ class CustomMockMixin:
 
         return klass(**kw)
 
-    def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
-        """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
-        for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
-            setattr(self, name, AsyncMock())
-
-
-# TODO: Remove me in Python 3.8
-class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
-    """
-    A MagicMock subclass to mock async callables.
-
-    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
-    features; this stand-in only overwrites the `__call__` method to an async version.
- """ - - async def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) - - -class AsyncIteratorMock: - """ - A class to mock asynchronous iterators. - - This allows async for, which is used in certain Discord.py objects. For example, - an async iterator is returned by the Reaction.users() method. - """ - - def __init__(self, iterable: Iterable = None): - if iterable is None: - iterable = [] - - self.iter = iter(iterable) - self.iterable = iterable - - self.call_count = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - - def __call__(self): - """ - Keeps track of the number of times an instance has been called. - - This is useful, since it typically shows that the iterator has actually been used somewhere after we have - instantiated the mock for an attribute that normally returns an iterator when called. - """ - self.call_count += 1 - return self - - @property - def return_value(self): - """Makes `self.iterable` accessible as self.return_value.""" - return self.iterable - - @return_value.setter - def return_value(self, iterable): - """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" - self.iter = iter(iterable) - self.iterable = iterable - - def assert_called(self): - """Asserts if the AsyncIteratorMock instance has been called at least once.""" - if self.call_count == 0: - raise AssertionError("Expected AsyncIteratorMock to have been called.") - - def assert_called_once(self): - """Asserts if the AsyncIteratorMock instance has been called exactly once.""" - if self.call_count != 1: - raise AssertionError( - f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." - ) - - def assert_not_called(self): - """Asserts if the AsyncIteratorMock instance has not been called.""" - if self.call_count != 0: - raise AssertionError( - f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." - ) - - def reset_mock(self): - """Resets the call count, but not the return value or iterator.""" - self.call_count = 0 - # Create a guild instance to get a realistic Mock of `discord.Guild` guild_data = { @@ -247,9 +157,11 @@ class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin): For more info, see the `Mocking` section in `tests/README.md`. """ + spec_set = guild_instance + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'members': []} - super().__init__(spec_set=guild_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: @@ -268,9 +180,23 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.Role` instances. For more information, see the `MockGuild` docstring. 
""" + spec_set = role_instance + def __init__(self, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} - super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) + default_kwargs = { + 'id': next(self.discord_id), + 'name': 'role', + 'position': 1, + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), + } + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) + + if isinstance(self.colour, int): + self.colour = discord.Colour(self.colour) + + if isinstance(self.permissions, int): + self.permissions = discord.Permissions(self.permissions) if 'mention' not in kwargs: self.mention = f'&{self.name}' @@ -293,9 +219,11 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin Instances of this class will follow the specifications of `discord.Member` instances. For more information, see the `MockGuild` docstring. """ + spec_set = member_instance + def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} - super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] if roles: @@ -316,14 +244,26 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): Instances of this class will follow the specifications of `discord.User` instances. For more information, see the `MockGuild` docstring. """ + spec_set = user_instance + def __init__(self, **kwargs) -> None: default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} - super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f"@{self.name}" +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock APIClient objects. + + Instances of this class will follow the specifications of `bot.api.APIClient` instances. + For more information, see the `MockGuild` docstring. + """ + spec_set = APIClient + + # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) bot_instance.http_session = None @@ -337,14 +277,12 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ + spec_set = bot_instance + additional_spec_asyncs = ("wait_for",) def __init__(self, **kwargs) -> None: - super().__init__(spec_set=bot_instance, **kwargs) - - # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and - # and should therefore be awaited. (The documentation calls it a coroutine as well, which - # is technically incorrect, since it's a regular def.) 
- self.wait_for = AsyncMock() + super().__init__(**kwargs) + self.api_client = MockAPIClient() # Since calling `create_task` on our MockBot does not actually schedule the coroutine object # as a task in the asyncio loop, this `side_effect` calls `close()` on the coroutine object @@ -375,10 +313,11 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ + spec_set = channel_instance def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} - super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) if 'mention' not in kwargs: self.mention = f"#{self.name}" @@ -417,9 +356,10 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ + spec_set = context_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=context_instance, **kwargs) + super().__init__(**kwargs) self.bot = kwargs.get('bot', MockBot()) self.guild = kwargs.get('guild', MockGuild()) self.author = kwargs.get('author', MockMember()) @@ -436,8 +376,7 @@ class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Attachment` instances. For more information, see the `MockGuild` docstring. """ - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=attachment_instance, **kwargs) + spec_set = attachment_instance class MockMessage(CustomMockMixin, unittest.mock.MagicMock): @@ -447,10 +386,11 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ + spec_set = message_instance def __init__(self, **kwargs) -> None: default_kwargs = {'attachments': []} - super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) + super().__init__(**collections.ChainMap(kwargs, default_kwargs)) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) @@ -466,9 +406,10 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ + spec_set = emoji_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=emoji_instance, **kwargs) + super().__init__(**kwargs) self.guild = kwargs.get('guild', MockGuild()) @@ -482,9 +423,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. 
""" - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=partial_emoji_instance, **kwargs) + spec_set = partial_emoji_instance reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji()) @@ -497,12 +436,19 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ + spec_set = reaction_instance def __init__(self, **kwargs) -> None: - super().__init__(spec_set=reaction_instance, **kwargs) + _users = kwargs.pop("users", []) + super().__init__(**kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) - self.users = AsyncIteratorMock(kwargs.get('users', [])) + + user_iterator = unittest.mock.AsyncMock() + user_iterator.__aiter__.return_value = _users + self.users.return_value = user_iterator + + self.__str__.return_value = str(self.emoji) webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) @@ -515,13 +461,5 @@ class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Webhook` instances. For more information, see the `MockGuild` docstring. """ - - def __init__(self, **kwargs) -> None: - super().__init__(spec_set=webhook_instance, **kwargs) - - # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined - # as coroutines. That's why we need to set the methods manually. - self.send = AsyncMock() - self.edit = AsyncMock() - self.delete = AsyncMock() - self.execute = AsyncMock() + spec_set = webhook_instance + additional_spec_asyncs = ("send", "edit", "delete", "execute") diff --git a/tests/test_base.py b/tests/test_base.py index a16e2af8f..a7db4bf3e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -3,7 +3,11 @@ import unittest import unittest.mock -from tests.base import LoggingTestCase, _CaptureLogHandler +from tests.base import LoggingTestsMixin, _CaptureLogHandler + + +class LoggingTestCase(LoggingTestsMixin, unittest.TestCase): + pass class LoggingTestCaseTests(unittest.TestCase): @@ -18,24 +22,14 @@ class LoggingTestCaseTests(unittest.TestCase): try: with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): pass - except AssertionError: + except AssertionError: # pragma: no cover self.fail("`self.assertNotLogs` raised an AssertionError when it should not!") - @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs") - def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs): - """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly.""" - assertNotLogs.return_value = iter([None]) - assertNotLogs.side_effect = AssertionError - - message = "`self.assertNotLogs` raised an AssertionError when it should not!" 
- with self.assertRaises(AssertionError, msg=message): - self.test_assert_not_logs_does_not_raise_with_no_logs() - def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self): """Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted.""" msg_regex = ( r"1 logs of DEBUG or higher were triggered on root:\n" - r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">' + r'<LogRecord: tests\.test_base, [\d]+, .+[/\\]tests[/\\]test_base\.py, [\d]+, "Log!">' ) with self.assertRaisesRegex(AssertionError, msg_regex): with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 7894e104a..81285e009 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,4 @@ import asyncio -import inspect import unittest import unittest.mock @@ -214,6 +213,11 @@ class DiscordMocksTests(unittest.TestCase): with self.assertRaises(RuntimeError, msg="cannot reuse already awaited coroutine"): asyncio.run(coroutine_object) + def test_user_mock_uses_explicitly_passed_mention_attribute(self): + """MockUser should use an explicitly passed value for user.mention.""" + user = helpers.MockUser(mention="hello") + self.assertEqual(user.mention, "hello") + class MockObjectTests(unittest.TestCase): """Tests the mock objects and mixins we've defined.""" @@ -341,65 +345,10 @@ class MockObjectTests(unittest.TestCase): attribute = getattr(mock, valid_attribute) self.assertTrue(isinstance(attribute, mock_type.child_mock_type)) - def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self): - """Test if all coroutine functions are extracted, but not regular methods or attributes.""" - class CoroutineDonor: - def __init__(self): - self.some_attribute = 'alpha' - - async def first_coroutine(): - """This coroutine function should be extracted.""" - - async def second_coroutine(): - """This coroutine function should be extracted.""" - - def regular_method(): - """This regular function should not be extracted.""" - - class Receiver: + def test_custom_mock_mixin_mocks_async_magic_methods_with_async_mock(self): + """The CustomMockMixin should mock async magic methods with an AsyncMock.""" + class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock): pass - donor = CoroutineDonor() - receiver = Receiver() - - helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor) - - self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock) - self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock) - self.assertFalse(hasattr(receiver, 'regular_method')) - self.assertFalse(hasattr(receiver, 'some_attribute')) - - @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) - @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") - def test_custom_mock_mixin_init_with_spec(self, extract_method_mock): - """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" - spec_set = "pydis" - - helpers.CustomMockMixin(spec_set=spec_set) - - extract_method_mock.assert_called_once_with(spec_set) - - @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock()) - @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance") - def test_custom_mock_mixin_init_without_spec(self, extract_method_mock): - """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method.""" - 
helpers.CustomMockMixin()
-
-        extract_method_mock.assert_not_called()
-
-    def test_async_mock_provides_coroutine_for_dunder_call(self):
-        """Test if AsyncMock objects have a coroutine for their __call__ method."""
-        async_mock = helpers.AsyncMock()
-        self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__))
-
-        coroutine = async_mock()
-        self.assertTrue(inspect.iscoroutine(coroutine))
-        self.assertIsNotNone(asyncio.run(coroutine))
-
-    def test_async_test_decorator_allows_synchronous_call_to_async_def(self):
-        """Test if the `async_test` decorator allows an `async def` to be called synchronously."""
-        @helpers.async_test
-        async def kosayoda():
-            return "return value"
-
-        self.assertEqual(kosayoda(), "return value")
+        mock = MyMock()
+        self.assertIsInstance(mock.__aenter__, unittest.mock.AsyncMock)
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
deleted file mode 100644
index 4baa6395c..000000000
--- a/tests/utils/test_time.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import asyncio
-from datetime import datetime, timezone
-from unittest.mock import patch
-
-import pytest
-from dateutil.relativedelta import relativedelta
-
-from bot.utils import time
-from tests.helpers import AsyncMock
-
-
-@pytest.mark.parametrize(
-    ('delta', 'precision', 'max_units', 'expected'),
-    (
-        (relativedelta(days=2), 'seconds', 1, '2 days'),
-        (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
-        (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
-        (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
-
-        # Does not abort for unknown units, as the unit name is checked
-        # against the attribute of the relativedelta instance.
-        (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'),
-
-        # Very high maximum units, but it only ever iterates over
-        # each value the relativedelta might have.
-        (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'),
-    )
-)
-def test_humanize_delta(
-    delta: relativedelta,
-    precision: str,
-    max_units: int,
-    expected: str
-):
-    assert time.humanize_delta(delta, precision, max_units) == expected
-
-
-@pytest.mark.parametrize('max_units', (-1, 0))
-def test_humanize_delta_raises_for_invalid_max_units(max_units: int):
-    with pytest.raises(ValueError, match='max_units must be positive'):
-        time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
-
-
-@pytest.mark.parametrize(
-    ('stamp', 'expected'),
-    (
-        ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)),
-    )
-)
-def test_parse_rfc1123(stamp: str, expected: str):
-    assert time.parse_rfc1123(stamp) == expected
-
-
-@patch('asyncio.sleep', new_callable=AsyncMock)
-def test_wait_until(sleep_patch):
-    start = datetime(2019, 1, 1, 0, 0)
-    then = datetime(2019, 1, 1, 0, 10)
-
-    # No return value
-    assert asyncio.run(time.wait_until(then, start)) is None
-
-    sleep_patch.assert_called_once_with(10 * 60)
@@ -3,7 +3,7 @@ max-line-length=120
 docstring-convention=all
 import-order-style=pycharm
 application_import_names=bot,tests
-exclude=.cache,.venv,constants.py
+exclude=.cache,.venv,.git,constants.py
 ignore=
     B311,W503,E226,S311,T000
     # Missing Docstrings
@@ -15,5 +15,5 @@ ignore=
     # Docstring Content
     D400,D401,D402,D404,D405,D406,D407,D408,D409,D410,D411,D412,D413,D414,D416,D417
     # Type Annotations
-    TYP002,TYP003,TYP101,TYP102,TYP204,TYP206
-per-file-ignores=tests/*:D,TYP
+    ANN002,ANN003,ANN101,ANN102,ANN204,ANN206
+per-file-ignores=tests/*:D,ANN
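
Note on the converted rule tests: they all defer to a shared RuleTest base imported from tests.bot.rules, driving each rule's apply coroutine through run_allowed/run_disallowed with DisallowedCase(recent_messages, culprits, n_violations) fixtures. The sketch below only illustrates how such a harness can work, inferred from the behaviour visible in the old tests (apply returns None for allowed cases and a (report, culprits, relevant_messages) tuple otherwise); it is not necessarily the exact implementation shipped in tests.bot.rules, and its use of unittest.IsolatedAsyncioTestCase is an assumption based on the test_api changes above.

# Illustrative sketch of a RuleTest-style harness; names mirror the tests above,
# but the actual tests.bot.rules implementation may differ.
from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
from unittest import IsolatedAsyncioTestCase

from tests.helpers import MockMessage


class DisallowedCase(NamedTuple):
    """A case that is expected to trigger the rule under test."""

    recent_messages: List[MockMessage]
    culprits: Iterable[str]
    n_violations: int


class RuleTest(IsolatedAsyncioTestCase):
    """Drives a rule's `apply` coroutine against allowed and disallowed cases."""

    apply: Callable  # Assigned by subclasses in setUp, e.g. `duplicates.apply`
    config: Dict[str, int]

    async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
        """Assert that the rule stays silent (returns None) for each case."""
        for recent_messages in cases:
            last_message = recent_messages[0]

            with self.subTest(last_message=last_message, config=self.config):
                self.assertIsNone(await self.apply(last_message, recent_messages, self.config))

    async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
        """Assert that the rule reports the expected culprits and messages for each case."""
        for case in cases:
            last_message = case.recent_messages[0]
            expected = (
                self.get_report(case),
                case.culprits,
                tuple(self.relevant_messages(case)),
            )

            with self.subTest(case=case, config=self.config):
                self.assertTupleEqual(
                    await self.apply(last_message, case.recent_messages, self.config),
                    expected,
                )

    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
        """Return the messages the rule is expected to flag; overridden per rule."""
        raise NotImplementedError

    def get_report(self, case: DisallowedCase) -> str:
        """Return the report string expected for `case`; overridden per rule."""
        raise NotImplementedError

Centralising the subTest loops this way is what lets each per-rule module above shrink to a make_msg factory, a config, the case tuples, and the two rule-specific hooks.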