diff options
author | 2019-12-13 23:11:25 +0100 | |
---|---|---|
committer | 2019-12-13 23:11:25 +0100 | |
commit | bf6a576608228a4cefb10150b3c3e082ab7ccf0a (patch) | |
tree | 2d572b5780e6b26dea2b11d3d749ad6c83e4ffb6 | |
parent | Specify assertion to be a tuple comparison (diff) | |
parent | Use OAuth to be Reddit API compliant (#696) (diff) |
Merge branch 'master' into unittest-mentions
61 files changed, 2166 insertions, 558 deletions
diff --git a/Pipfile.lock b/Pipfile.lock index 95955ff89..69caf4646 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -18,11 +18,11 @@ "default": { "aio-pika": { "hashes": [ - "sha256:1dcec3e3e3309e277511dc0d7d157676d0165c174a6a745673fc9cf0510db8f0", - "sha256:dd5a23ca26a4872ee73bd107e4c545bace572cdec2a574aeb61f4062c7774b2a" + "sha256:1da038b3d2c1b49e0e816d87424e702912bb77f9b5197f2bf279217915b4f7ed", + "sha256:29fe851374b86c997a22174c04352b5941bc1c2e36bbf542918ac18a76cfc9d3" ], "index": "pypi", - "version": "==6.1.3" + "version": "==6.3.0" }, "aiodns": { "hashes": [ @@ -62,10 +62,10 @@ }, "aiormq": { "hashes": [ - "sha256:c3e4dd01a2948a75f739fb637334dbb8c6f1a4cecf74d5ed662dc3bab7f39973", - "sha256:e220d3f9477bb2959b729b79bec815148ddb8a7686fc6c3d05d41c88ebd7c59e" + "sha256:afc0d46837b121585e4faec0a7646706429b4e2f5110ae8d0b5cdc3708b4b0e5", + "sha256:dc0fbbc7f8ad5af6a2cc18e00ccc5f925984cde3db6e8fe952c07b7ef157b5f2" ], - "version": "==2.8.0" + "version": "==2.9.1" }, "alabaster": { "hashes": [ @@ -83,10 +83,10 @@ }, "attrs": { "hashes": [ - "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", - "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" ], - "version": "==19.2.0" + "version": "==19.3.0" }, "babel": { "hashes": [ @@ -112,36 +112,41 @@ }, "cffi": { "hashes": [ - "sha256:041c81822e9f84b1d9c401182e174996f0bae9991f33725d059b771744290774", - "sha256:046ef9a22f5d3eed06334d01b1e836977eeef500d9b78e9ef693f9380ad0b83d", - "sha256:066bc4c7895c91812eff46f4b1c285220947d4aa46fa0a2651ff85f2afae9c90", - "sha256:066c7ff148ae33040c01058662d6752fd73fbc8e64787229ea8498c7d7f4041b", - "sha256:2444d0c61f03dcd26dbf7600cf64354376ee579acad77aef459e34efcb438c63", - "sha256:300832850b8f7967e278870c5d51e3819b9aad8f0a2c8dbe39ab11f119237f45", - "sha256:34c77afe85b6b9e967bd8154e3855e847b70ca42043db6ad17f26899a3df1b25", - "sha256:46de5fa00f7ac09f020729148ff632819649b3e05a007d286242c4882f7b1dc3", - "sha256:4aa8ee7ba27c472d429b980c51e714a24f47ca296d53f4d7868075b175866f4b", - "sha256:4d0004eb4351e35ed950c14c11e734182591465a33e960a4ab5e8d4f04d72647", - "sha256:4e3d3f31a1e202b0f5a35ba3bc4eb41e2fc2b11c1eff38b362de710bcffb5016", - "sha256:50bec6d35e6b1aaeb17f7c4e2b9374ebf95a8975d57863546fa83e8d31bdb8c4", - "sha256:55cad9a6df1e2a1d62063f79d0881a414a906a6962bc160ac968cc03ed3efcfb", - "sha256:5662ad4e4e84f1eaa8efce5da695c5d2e229c563f9d5ce5b0113f71321bcf753", - "sha256:59b4dc008f98fc6ee2bb4fd7fc786a8d70000d058c2bbe2698275bc53a8d3fa7", - "sha256:73e1ffefe05e4ccd7bcea61af76f36077b914f92b76f95ccf00b0c1b9186f3f9", - "sha256:a1f0fd46eba2d71ce1589f7e50a9e2ffaeb739fb2c11e8192aa2b45d5f6cc41f", - "sha256:a2e85dc204556657661051ff4bab75a84e968669765c8a2cd425918699c3d0e8", - "sha256:a5457d47dfff24882a21492e5815f891c0ca35fefae8aa742c6c263dac16ef1f", - "sha256:a8dccd61d52a8dae4a825cdbb7735da530179fea472903eb871a5513b5abbfdc", - "sha256:ae61af521ed676cf16ae94f30fe202781a38d7178b6b4ab622e4eec8cefaff42", - "sha256:b012a5edb48288f77a63dba0840c92d0504aa215612da4541b7b42d849bc83a3", - "sha256:d2c5cfa536227f57f97c92ac30c8109688ace8fa4ac086d19d0af47d134e2909", - "sha256:d42b5796e20aacc9d15e66befb7a345454eef794fdb0737d1af593447c6c8f45", - "sha256:dee54f5d30d775f525894d67b1495625dd9322945e7fee00731952e0368ff42d", - "sha256:e070535507bd6aa07124258171be2ee8dfc19119c28ca94c9dfb7efd23564512", - "sha256:e1ff2748c84d97b065cc95429814cdba39bcbd77c9c85c89344b317dc0d9cbff", - "sha256:ed851c75d1e0e043cbf5ca9a8e1b13c4c90f3fbd863dacb01c0808e2b5204201" - ], - "version": "==1.12.3" + "sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42", + "sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04", + "sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5", + "sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54", + "sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba", + "sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57", + "sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396", + "sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12", + "sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97", + "sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43", + "sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db", + "sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3", + "sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b", + "sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579", + "sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346", + "sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159", + "sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652", + "sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e", + "sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a", + "sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506", + "sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f", + "sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d", + "sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c", + "sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20", + "sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858", + "sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc", + "sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a", + "sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3", + "sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e", + "sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410", + "sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25", + "sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b", + "sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d" + ], + "version": "==1.13.2" }, "chardet": { "hashes": [ @@ -152,18 +157,18 @@ }, "deepdiff": { "hashes": [ - "sha256:1123762580af0904621136d117c8397392a244d3ff0fa0a50de57a7939582476", - "sha256:6ab13e0cbb627dadc312deaca9bef38de88a737a9bbdbfbe6e3857748219c127" + "sha256:3457ea7cecd51ba48015d89edbb569358af4d9b9e65e28bdb3209608420627f9", + "sha256:5e2343398e90538edaa59c0c99207e996a3a834fdc878c666376f632a760c35a" ], "index": "pypi", - "version": "==4.0.7" + "version": "==4.0.9" }, "discord-py": { "hashes": [ - "sha256:4684733fa137cc7def18087ae935af615212e423e3dbbe3e84ef01d7ae8ed17d" + "sha256:7c843b523bb011062b453864e75c7b675a03faf573c58d14c9f096e85984329d" ], "index": "pypi", - "version": "==1.2.3" + "version": "==1.2.5" }, "docutils": { "hashes": [ @@ -221,6 +226,7 @@ "sha256:02ca7bf899da57084041bb0f6095333e4d239948ad3169443f454add9f4e9cb4", "sha256:096b82c5e0ea27ce9138bcbb205313343ee66a6e132f25c5ed67e2c8d960a1bc", "sha256:0a920ff98cf1aac310470c644bc23b326402d3ef667ddafecb024e1713d485f1", + "sha256:1409b14bf83a7d729f92e2a7fbfe7ec929d4883ca071b06e95c539ceedb6497c", "sha256:17cae1730a782858a6e2758fd20dd0ef7567916c47757b694a06ffafdec20046", "sha256:17e3950add54c882e032527795c625929613adbd2ce5162b94667334458b5a36", "sha256:1f4f214337f6ee5825bf90a65d04d70aab05526c08191ab888cb5149501923c5", @@ -231,11 +237,14 @@ "sha256:760c12276fee05c36f95f8040180abc7fbebb9e5011447a97cdc289b5d6ab6fc", "sha256:796685d3969815a633827c818863ee199440696b0961e200b011d79b9394bbe7", "sha256:891fe897b49abb7db470c55664b198b1095e4943b9f82b7dcab317a19116cd38", + "sha256:9277562f175d2334744ad297568677056861070399cec56ff06abbe2564d1232", "sha256:a471628e20f03dcdfde00770eeaf9c77811f0c331c8805219ca7b87ac17576c5", "sha256:a63b4fd3e2cabdcc9d918ed280bdde3e8e9641e04f3c59a2a3109644a07b9832", + "sha256:ae88588d687bd476be588010cbbe551e9c2872b816f2da8f01f6f1fda74e1ef0", "sha256:b0b84408d4eabc6de9dd1e1e0bc63e7731e890c0b378a62443e5741cfd0ae90a", "sha256:be78485e5d5f3684e875dab60f40cddace2f5b2a8f7fede412358ab3214c3a6f", "sha256:c27eaed872185f047bb7f7da2d21a7d8913457678c9a100a50db6da890bc28b9", + "sha256:c7fccd08b14aa437fe096c71c645c0f9be0655a9b1a4b7cffc77bcb23b3d61d2", "sha256:c81cb40bff373ab7a7446d6bbca0190bccc5be3448b47b51d729e37799bb5692", "sha256:d11874b3c33ee441059464711cd365b89fa1a9cf19ae75b0c189b01fbf735b84", "sha256:e9c028b5897901361d81a4718d1db217b716424a0283afe9d6735fe0caf70f79", @@ -379,18 +388,18 @@ }, "pyparsing": { "hashes": [ - "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", - "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" + "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", + "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" ], - "version": "==2.4.2" + "version": "==2.4.5" }, "python-dateutil": { "hashes": [ - "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb", - "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e" + "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", + "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" ], "index": "pypi", - "version": "==2.8.0" + "version": "==2.8.1" }, "python-json-logger": { "hashes": [ @@ -434,10 +443,10 @@ }, "six": { "hashes": [ - "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", - "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", + "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" ], - "version": "==1.12.0" + "version": "==1.13.0" }, "snowballstemmer": { "hashes": [ @@ -448,18 +457,18 @@ }, "soupsieve": { "hashes": [ - "sha256:605f89ad5fdbfefe30cdc293303665eff2d188865d4dbe4eb510bba1edfbfce3", - "sha256:b91d676b330a0ebd5b21719cb6e9b57c57d433671f65b9c28dd3461d9a1ed0b6" + "sha256:bdb0d917b03a1369ce964056fc195cfdff8819c40de04695a80bc813c3cfa1f5", + "sha256:e2c1c5dee4a1c36bcb790e0fabd5492d874b8ebd4617622c4f6a731701060dda" ], - "version": "==1.9.4" + "version": "==1.9.5" }, "sphinx": { "hashes": [ - "sha256:0d586b0f8c2fc3cc6559c5e8fd6124628110514fda0e5d7c82e682d749d2e845", - "sha256:839a3ed6f6b092bb60f492024489cc9e6991360fb9f52ed6361acd510d261069" + "sha256:31088dfb95359384b1005619827eaee3056243798c62724fd3fa4b84ee4d71bd", + "sha256:52286a0b9d7caa31efee301ec4300dbdab23c3b05da1c9024b4e84896fb73d79" ], "index": "pypi", - "version": "==2.2.0" + "version": "==2.2.1" }, "sphinxcontrib-applehelp": { "hashes": [ @@ -564,10 +573,10 @@ }, "attrs": { "hashes": [ - "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", - "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" ], - "version": "==19.2.0" + "version": "==19.3.0" }, "certifi": { "hashes": [ @@ -658,11 +667,11 @@ }, "flake8": { "hashes": [ - "sha256:19241c1cbc971b9962473e4438a2ca19749a7dd002dd1a946eaba171b4114548", - "sha256:8e9dfa3cecb2400b3738a42c54c3043e821682b9c840b0448c0503f781130696" + "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", + "sha256:49356e766643ad15072a789a20915d3c91dc89fd313ccd71802303fd67e4deca" ], "index": "pypi", - "version": "==3.7.8" + "version": "==3.7.9" }, "flake8-annotations": { "hashes": [ @@ -738,6 +747,7 @@ "sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26", "sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af" ], + "markers": "python_version < '3.8'", "version": "==0.23" }, "mccabe": { @@ -770,11 +780,11 @@ }, "pre-commit": { "hashes": [ - "sha256:1d3c0587bda7c4e537a46c27f2c84aa006acc18facf9970bf947df596ce91f3f", - "sha256:fa78ff96e8e9ac94c748388597693f18b041a181c94a4f039ad20f45287ba44a" + "sha256:9f152687127ec90642a2cc3e4d9e1e6240c4eb153615cb02aa1ad41d331cbb6e", + "sha256:c2e4810d2d3102d354947907514a78c5d30424d299dc0fe48f5aa049826e9b50" ], "index": "pypi", - "version": "==1.18.3" + "version": "==1.20.0" }, "pycodestyle": { "hashes": [ @@ -799,10 +809,10 @@ }, "pyparsing": { "hashes": [ - "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", - "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" + "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", + "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" ], - "version": "==2.4.2" + "version": "==2.4.5" }, "pyyaml": { "hashes": [ @@ -841,10 +851,10 @@ }, "six": { "hashes": [ - "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", - "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", + "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" ], - "version": "==1.12.0" + "version": "==1.13.0" }, "snowballstemmer": { "hashes": [ @@ -862,31 +872,36 @@ }, "typed-ast": { "hashes": [ + "sha256:1170afa46a3799e18b4c977777ce137bb53c7485379d9706af8a59f2ea1aa161", "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", + "sha256:48e5b1e71f25cfdef98b013263a88d7145879fbb2d5185f2a0c79fa7ebbeae47", "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", + "sha256:7954560051331d003b4e2b3eb822d9dd2e376fa4f6d98fee32f452f52dd6ebb2", + "sha256:838997f4310012cf2e1ad3803bce2f3402e9ffb71ded61b5ee22617b3a7f6b6e", "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", + "sha256:fdc1c9bbf79510b76408840e009ed65958feba92a88833cdceecff93ae8fff66", "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" ], "version": "==1.4.0" }, "unittest-xml-reporting": { "hashes": [ - "sha256:140982e4b58e4052d9ecb775525b246a96bfc1fc26097806e05ea06e9166dd6c", - "sha256:d1fbc7a1b6c6680ccfe75b5e9701e5431c646970de049e687b4bb35ba4325d72" + "sha256:358bbdaf24a26d904cc1c26ef3078bca7fc81541e0a54c8961693cc96a6f35e0", + "sha256:9d28ddf6524cf0ff9293f61bd12e792de298f8561a5c945acea63fb437789e0e" ], "index": "pypi", - "version": "==2.5.1" + "version": "==2.5.2" }, "urllib3": { "hashes": [ @@ -898,10 +913,10 @@ }, "virtualenv": { "hashes": [ - "sha256:680af46846662bb38c5504b78bad9ed9e4f3ba2d54f54ba42494fdf94337fe30", - "sha256:f78d81b62d3147396ac33fc9d77579ddc42cc2a98dd9ea38886f616b33bc7fb2" + "sha256:11cb4608930d5fd3afb545ecf8db83fa50e1f96fc4fca80c94b07d2c83146589", + "sha256:d257bb3773e48cac60e475a19b608996c73f4d333b3ba2e4e57d5ac6134e0136" ], - "version": "==16.7.5" + "version": "==16.7.7" }, "zipp": { "hashes": [ diff --git a/azure-pipelines.yml b/azure-pipelines.yml index da3b06201..0400ac4d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/__main__.py b/bot/__main__.py index f352cd60e..84bc7094b 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,18 +1,11 @@ -import asyncio -import logging -import socket - import discord -from aiohttp import AsyncResolver, ClientSession, TCPConnector -from discord.ext.commands import Bot, when_mentioned_or +from discord.ext.commands import when_mentioned_or from bot import patches -from bot.api import APIClient, APILoggingHandler +from bot.bot import Bot from bot.constants import Bot as BotConfig, DEBUG_MODE -log = logging.getLogger('bot') - bot = Bot( command_prefix=when_mentioned_or(BotConfig.prefix), activity=discord.Game(name="Commands: !help"), @@ -20,18 +13,6 @@ bot = Bot( max_messages=10_000, ) -# Global aiohttp session for all cogs -# - Uses asyncio for DNS resolution instead of threads, so we don't spam threads -# - Uses AF_INET as its socket family to prevent https related problems both locally and in prod. -bot.http_session = ClientSession( - connector=TCPConnector( - resolver=AsyncResolver(), - family=socket.AF_INET, - ) -) -bot.api_client = APIClient(loop=asyncio.get_event_loop()) -log.addHandler(APILoggingHandler(bot.api_client)) - # Internal/debug bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") @@ -55,6 +36,7 @@ if not DEBUG_MODE: bot.load_extension("bot.cogs.alias") bot.load_extension("bot.cogs.defcon") bot.load_extension("bot.cogs.eval") +bot.load_extension("bot.cogs.duck_pond") bot.load_extension("bot.cogs.free") bot.load_extension("bot.cogs.information") bot.load_extension("bot.cogs.jams") @@ -76,6 +58,3 @@ if not hasattr(discord.message.Message, '_handle_edited_timestamp'): patches.message_edited_at.apply_patch() bot.run(BotConfig.token) - -# This calls a coroutine, so it doesn't do anything at the moment. -# bot.http_session.close() # Close the aiohttp session when the bot finishes running diff --git a/bot/api.py b/bot/api.py index 7f26e5305..56db99828 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,7 +32,7 @@ class ResponseCodeError(ValueError): class APIClient: """Django Site API wrapper.""" - def __init__(self, **kwargs): + def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs): auth_headers = { 'Authorization': f"Token {Keys.site_api}" } @@ -42,12 +42,39 @@ class APIClient: else: kwargs['headers'] = auth_headers - self.session = aiohttp.ClientSession(**kwargs) + self.session: Optional[aiohttp.ClientSession] = None + self.loop = loop + + self._ready = asyncio.Event(loop=loop) + self._creation_task = None + self._session_args = kwargs + + self.recreate() @staticmethod def _url_for(endpoint: str) -> str: return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" + async def _create_session(self) -> None: + """Create the aiohttp session and set the ready event.""" + self.session = aiohttp.ClientSession(**self._session_args) + self._ready.set() + + async def close(self) -> None: + """Close the aiohttp session and unset the ready event.""" + if not self._ready.is_set(): + return + + await self.session.close() + self._ready.clear() + + def recreate(self) -> None: + """Schedule the aiohttp session to be created if it's been closed.""" + if self.session is None or self.session.closed: + # Don't schedule a task if one is already in progress. + if self._creation_task is None or self._creation_task.done(): + self._creation_task = self.loop.create_task(self._create_session()) + async def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool) -> None: """Raise ResponseCodeError for non-OK response if an exception should be raised.""" if should_raise and response.status >= 400: @@ -60,30 +87,40 @@ class APIClient: async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API GET.""" + await self._ready.wait() + async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PATCH.""" + await self._ready.wait() + async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API POST.""" + await self._ready.wait() + async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: """Site API PUT.""" + await self._ready.wait() + async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: """Site API DELETE.""" + await self._ready.wait() + async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: if resp.status == 204: return None diff --git a/bot/bot.py b/bot/bot.py new file mode 100644 index 000000000..8f808272f --- /dev/null +++ b/bot/bot.py @@ -0,0 +1,53 @@ +import logging +import socket +from typing import Optional + +import aiohttp +from discord.ext import commands + +from bot import api + +log = logging.getLogger('bot') + + +class Bot(commands.Bot): + """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" + + def __init__(self, *args, **kwargs): + # Use asyncio for DNS resolution instead of threads so threads aren't spammed. + # Use AF_INET as its socket family to prevent HTTPS related problems both locally + # and in production. + self.connector = aiohttp.TCPConnector( + resolver=aiohttp.AsyncResolver(), + family=socket.AF_INET, + ) + + super().__init__(*args, connector=self.connector, **kwargs) + + self.http_session: Optional[aiohttp.ClientSession] = None + self.api_client = api.APIClient(loop=self.loop, connector=self.connector) + + log.addHandler(api.APILoggingHandler(self.api_client)) + + def add_cog(self, cog: commands.Cog) -> None: + """Adds a "cog" to the bot and logs the operation.""" + super().add_cog(cog) + log.info(f"Cog loaded: {cog.qualified_name}") + + def clear(self) -> None: + """Clears the internal state of the bot and resets the API client.""" + super().clear() + self.api_client.recreate() + + async def close(self) -> None: + """Close the aiohttp session after closing the Discord connection.""" + await super().close() + + await self.http_session.close() + await self.api_client.close() + + async def start(self, *args, **kwargs) -> None: + """Open an aiohttp session before logging in and connecting to Discord.""" + self.http_session = aiohttp.ClientSession(connector=self.connector) + + await super().start(*args, **kwargs) diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 5190c559b..c1db38462 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -3,8 +3,9 @@ import logging from typing import Union from discord import Colour, Embed, Member, User -from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group +from discord.ext.commands import Cog, Command, Context, clean_content, command, group +from bot.bot import Bot from bot.cogs.extensions import Extension from bot.cogs.watchchannels.watchchannel import proxy_user from bot.converters import TagNameConverter @@ -147,6 +148,5 @@ class Alias (Cog): def setup(bot: Bot) -> None: - """Alias cog load.""" + """Load the Alias cog.""" bot.add_cog(Alias(bot)) - log.info("Cog loaded: Alias") diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index ababd6f18..28e3e5d96 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -1,9 +1,10 @@ import logging -from discord import Message, NotFound -from discord.ext.commands import Bot, Cog +from discord import Embed, Message, NotFound +from discord.ext.commands import Cog -from bot.constants import AntiMalware as AntiMalwareConfig, Channels +from bot.bot import Bot +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs log = logging.getLogger(__name__) @@ -17,31 +18,29 @@ class AntiMalware(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" - rejected_attachments = False - detected_pyfile = False + if not message.attachments: + return + + embed = Embed() for attachment in message.attachments: - if attachment.filename.lower().endswith('.py'): - detected_pyfile = True - break # Other detections irrelevant because we prioritize the .py message. - if not attachment.filename.lower().endswith(tuple(AntiMalwareConfig.whitelist)): - rejected_attachments = True - - if detected_pyfile or rejected_attachments: - # Send a message to the user indicating the problem (with special treatment for .py) - author = message.author - if detected_pyfile: - msg = ( - f"{author.mention}, it looks like you tried to attach a Python file - please " - f"use a code-pasting service such as https://paste.pythondiscord.com/ instead." + filename = attachment.filename.lower() + if filename.endswith('.py'): + embed.description = ( + f"It looks like you tried to attach a Python file - please " + f"use a code-pasting service such as {URLs.site_schema}{URLs.site_paste}" ) - else: + break # Other detections irrelevant because we prioritize the .py message. + if not filename.endswith(tuple(AntiMalwareConfig.whitelist)): + whitelisted_types = ', '.join(AntiMalwareConfig.whitelist) meta_channel = self.bot.get_channel(Channels.meta) - msg = ( - f"{author.mention}, it looks like you tried to attach a file type we don't " - f"allow. Feel free to ask in {meta_channel.mention} if you think this is a mistake." + embed.description = ( + f"It looks like you tried to attach a file type that we " + f"do not allow. We currently allow the following file " + f"types: **{whitelisted_types}**. \n\n Feel free to ask " + f"in {meta_channel.mention} if you think this is a mistake." ) - - await message.channel.send(msg) + if embed.description: + await message.channel.send(f"Hey {message.author.mention}!", embed=embed) # Delete the offending message: try: @@ -51,6 +50,5 @@ class AntiMalware(Cog): def setup(bot: Bot) -> None: - """Antimalware cog load.""" + """Load the AntiMalware cog.""" bot.add_cog(AntiMalware(bot)) - log.info("Cog loaded: AntiMalware") diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index 1340eb608..f454061a6 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -7,9 +7,10 @@ from operator import itemgetter from typing import Dict, Iterable, List, Set from discord import Colour, Member, Message, NotFound, Object, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from bot import rules +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( AntiSpam as AntiSpamConfig, Channels, @@ -276,7 +277,6 @@ def validate_config(rules: Mapping = AntiSpamConfig.rules) -> Dict[str, str]: def setup(bot: Bot) -> None: - """Antispam cog load.""" + """Validate the AntiSpam configs and load the AntiSpam cog.""" validation_errors = validate_config() bot.add_cog(AntiSpam(bot, validation_errors)) - log.info("Cog loaded: AntiSpam") diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 7583b2f2d..73b1e8f41 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -4,9 +4,11 @@ import re import time from typing import Optional, Tuple -from discord import Embed, Message, RawMessageUpdateEvent -from discord.ext.commands import Bot, Cog, Context, command, group +from discord import Embed, Message, RawMessageUpdateEvent, TextChannel +from discord.ext.commands import Cog, Context, command, group +from bot.bot import Bot +from bot.cogs.token_remover import TokenRemover from bot.constants import Channels, DEBUG_MODE, Guild, MODERATION_ROLES, Roles, URLs from bot.decorators import with_role from bot.utils.messages import wait_for_deletion @@ -16,7 +18,7 @@ log = logging.getLogger(__name__) RE_MARKDOWN = re.compile(r'([*_~`|>])') -class Bot(Cog): +class BotCog(Cog, name="Bot"): """Bot information commands.""" def __init__(self, bot: Bot): @@ -71,9 +73,12 @@ class Bot(Cog): @command(name='echo', aliases=('print',)) @with_role(*MODERATION_ROLES) - async def echo_command(self, ctx: Context, *, text: str) -> None: - """Send the input verbatim to the current channel.""" - await ctx.send(text) + async def echo_command(self, ctx: Context, channel: Optional[TextChannel], *, text: str) -> None: + """Repeat the given message in either a specified channel or the current channel.""" + if channel is None: + await ctx.send(text) + else: + await channel.send(text) @command(name='embed') @with_role(*MODERATION_ROLES) @@ -235,9 +240,10 @@ class Bot(Cog): ) and not msg.author.bot and len(msg.content.splitlines()) > 3 + and not TokenRemover.is_token_in_message(msg) ) - if parse_codeblock: + if parse_codeblock: # no token in the msg on_cooldown = (time.time() - self.channel_cooldowns.get(msg.channel.id, 0)) < 300 if not on_cooldown or DEBUG_MODE: try: @@ -370,10 +376,9 @@ class Bot(Cog): bot_message = await channel.fetch_message(self.codeblock_message_ids[payload.message_id]) await bot_message.delete() del self.codeblock_message_ids[payload.message_id] - log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") + log.trace("User's incorrect code block has been fixed. Removing bot formatting message.") def setup(bot: Bot) -> None: - """Bot cog load.""" - bot.add_cog(Bot(bot)) - log.info("Cog loaded: Bot") + """Load the Bot cog.""" + bot.add_cog(BotCog(bot)) diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index dca411d01..2104efe57 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -3,9 +3,10 @@ import random import re from typing import Optional -from discord import Colour, Embed, Message, User -from discord.ext.commands import Bot, Cog, Context, group +from discord import Colour, Embed, Message, TextChannel, User +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, CleanMessages, Colours, Event, @@ -37,9 +38,13 @@ class Clean(Cog): return self.bot.get_cog("ModLog") async def _clean_messages( - self, amount: int, ctx: Context, - bots_only: bool = False, user: User = None, - regex: Optional[str] = None + self, + amount: int, + ctx: Context, + bots_only: bool = False, + user: User = None, + regex: Optional[str] = None, + channel: Optional[TextChannel] = None ) -> None: """A helper function that does the actual message cleaning.""" def predicate_bots_only(message: Message) -> bool: @@ -104,6 +109,10 @@ class Clean(Cog): else: predicate = None # Delete all messages + # Default to using the invoking context's channel + if not channel: + channel = ctx.channel + # Look through the history and retrieve message data messages = [] message_ids = [] @@ -111,7 +120,7 @@ class Clean(Cog): invocation_deleted = False # To account for the invocation message, we index `amount + 1` messages. - async for message in ctx.channel.history(limit=amount + 1): + async for message in channel.history(limit=amount + 1): # If at any point the cancel command is invoked, we should stop. if not self.cleaning: @@ -135,7 +144,7 @@ class Clean(Cog): self.mod_log.ignore(Event.message_delete, *message_ids) # Use bulk delete to actually do the cleaning. It's far faster. - await ctx.channel.purge( + await channel.purge( limit=amount, check=predicate ) @@ -155,7 +164,7 @@ class Clean(Cog): # Build the embed and send it message = ( - f"**{len(message_ids)}** messages deleted in <#{ctx.channel.id}> by **{ctx.author.name}**\n\n" + f"**{len(message_ids)}** messages deleted in <#{channel.id}> by **{ctx.author.name}**\n\n" f"A log of the deleted messages can be found [here]({log_url})." ) @@ -167,7 +176,7 @@ class Clean(Cog): channel_id=Channels.modlog, ) - @group(invoke_without_command=True, name="clean", hidden=True) + @group(invoke_without_command=True, name="clean", aliases=["purge"]) @with_role(*MODERATION_ROLES) async def clean_group(self, ctx: Context) -> None: """Commands for cleaning messages in channels.""" @@ -175,27 +184,49 @@ class Clean(Cog): @clean_group.command(name="user", aliases=["users"]) @with_role(*MODERATION_ROLES) - async def clean_user(self, ctx: Context, user: User, amount: int = 10) -> None: + async def clean_user( + self, + ctx: Context, + user: User, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete messages posted by the provided user, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, user=user) + await self._clean_messages(amount, ctx, user=user, channel=channel) @clean_group.command(name="all", aliases=["everything"]) @with_role(*MODERATION_ROLES) - async def clean_all(self, ctx: Context, amount: int = 10) -> None: + async def clean_all( + self, + ctx: Context, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages, regardless of poster, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx) + await self._clean_messages(amount, ctx, channel=channel) @clean_group.command(name="bots", aliases=["bot"]) @with_role(*MODERATION_ROLES) - async def clean_bots(self, ctx: Context, amount: int = 10) -> None: + async def clean_bots( + self, + ctx: Context, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages posted by a bot, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, bots_only=True) + await self._clean_messages(amount, ctx, bots_only=True, channel=channel) @clean_group.command(name="regex", aliases=["word", "expression"]) @with_role(*MODERATION_ROLES) - async def clean_regex(self, ctx: Context, regex: str, amount: int = 10) -> None: + async def clean_regex( + self, + ctx: Context, + regex: str, + amount: Optional[int] = 10, + channel: TextChannel = None + ) -> None: """Delete all messages that match a certain regex, stop cleaning after traversing `amount` messages.""" - await self._clean_messages(amount, ctx, regex=regex) + await self._clean_messages(amount, ctx, regex=regex, channel=channel) @clean_group.command(name="stop", aliases=["cancel", "abort"]) @with_role(*MODERATION_ROLES) @@ -211,6 +242,5 @@ class Clean(Cog): def setup(bot: Bot) -> None: - """Clean cog load.""" + """Load the Clean cog.""" bot.add_cog(Clean(bot)) - log.info("Cog loaded: Clean") diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index bedd70c86..3e7350fcc 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -6,8 +6,9 @@ from datetime import datetime, timedelta from enum import Enum from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles from bot.decorators import with_role @@ -236,6 +237,5 @@ class Defcon(Cog): def setup(bot: Bot) -> None: - """DEFCON cog load.""" + """Load the Defcon cog.""" bot.add_cog(Defcon(bot)) - log.info("Cog loaded: Defcon") diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 65cabe46f..9506b195a 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -4,17 +4,21 @@ import logging import re import textwrap from collections import OrderedDict +from contextlib import suppress from typing import Any, Callable, Optional, Tuple import discord from bs4 import BeautifulSoup -from bs4.element import PageElement +from bs4.element import PageElement, Tag +from discord.errors import NotFound from discord.ext import commands from markdownify import MarkdownConverter -from requests import ConnectionError +from requests import ConnectTimeout, ConnectionError, HTTPError from sphinx.ext import intersphinx +from urllib3.exceptions import ProtocolError -from bot.constants import MODERATION_ROLES +from bot.bot import Bot +from bot.constants import MODERATION_ROLES, RedirectOutput from bot.converters import ValidPythonIdentifier, ValidURL from bot.decorators import with_role from bot.pagination import LinePaginator @@ -23,10 +27,33 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) logging.getLogger('urllib3').setLevel(logging.WARNING) - -UNWANTED_SIGNATURE_SYMBOLS = ('[source]', '¶') +NO_OVERRIDE_GROUPS = ( + "2to3fixer", + "token", + "label", + "pdbcommand", + "term", +) +NO_OVERRIDE_PACKAGES = ( + "python", +) + +SEARCH_END_TAG_ATTRS = ( + "data", + "function", + "class", + "exception", + "seealso", + "section", + "rubric", + "sphinxsidebar", +) +UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") +FAILED_REQUEST_RETRY_AMOUNT = 3 +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay + def async_cache(max_size: int = 128, arg_offset: int = 0) -> Callable: """ @@ -121,10 +148,11 @@ class InventoryURL(commands.Converter): class Doc(commands.Cog): """A set of commands for querying & displaying documentation.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.base_urls = {} self.bot = bot self.inventories = {} + self.renamed_symbols = set() self.bot.loop.create_task(self.init_refresh_inventory()) @@ -150,13 +178,32 @@ class Doc(commands.Cog): """ self.base_urls[package_name] = base_url - fetch_func = functools.partial(intersphinx.fetch_inventory, config, '', inventory_url) - for _, value in (await self.bot.loop.run_in_executor(None, fetch_func)).items(): - # Each value has a bunch of information in the form - # `(package_name, version, relative_url, ???)`, and we only - # need the relative documentation URL. - for symbol, (_, _, relative_doc_url, _) in value.items(): + package = await self._fetch_inventory(inventory_url, config) + if not package: + return None + + for group, value in package.items(): + for symbol, (package_name, _version, relative_doc_url, _) in value.items(): absolute_doc_url = base_url + relative_doc_url + + if symbol in self.inventories: + group_name = group.split(":")[1] + symbol_base_url = self.inventories[symbol].split("/", 3)[2] + if ( + group_name in NO_OVERRIDE_GROUPS + or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) + ): + + symbol = f"{group_name}.{symbol}" + # If renamed `symbol` already exists, add library name in front to differentiate between them. + if symbol in self.renamed_symbols: + # Split `package_name` because of packages like Pillow that have spaces in them. + symbol = f"{package_name.split()[0]}.{symbol}" + + self.inventories[symbol] = absolute_doc_url + self.renamed_symbols.add(symbol) + continue + self.inventories[symbol] = absolute_doc_url log.trace(f"Fetched inventory for {package_name}.") @@ -170,6 +217,7 @@ class Doc(commands.Cog): # Also, reset the cache used for fetching documentation. self.base_urls.clear() self.inventories.clear() + self.renamed_symbols.clear() async_cache.cache = OrderedDict() # Since Intersphinx is intended to be used with Sphinx, @@ -185,16 +233,15 @@ class Doc(commands.Cog): ] await asyncio.gather(*coros) - async def get_symbol_html(self, symbol: str) -> Optional[Tuple[str, str]]: + async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: """ Given a Python symbol, return its signature and description. - Returns a tuple in the form (str, str), or `None`. - The first tuple element is the signature of the given symbol as a markup-free string, and the second tuple element is the description of the given symbol with HTML markup included. - If the given symbol could not be found, returns `None`. + If the given symbol is a module, returns a tuple `(None, str)` + else if the symbol could not be found, returns `None`. """ url = self.inventories.get(symbol) if url is None: @@ -207,21 +254,38 @@ class Doc(commands.Cog): symbol_id = url.split('#')[-1] soup = BeautifulSoup(html, 'lxml') symbol_heading = soup.find(id=symbol_id) - signature_buffer = [] + search_html = str(soup) if symbol_heading is None: return None - # Traverse the tags of the signature header and ignore any - # unwanted symbols from it. Add all of it to a temporary buffer. - for tag in symbol_heading.strings: - if tag not in UNWANTED_SIGNATURE_SYMBOLS: - signature_buffer.append(tag.replace('\\', '')) + if symbol_id == f"module-{symbol}": + # Get page content from the module headerlink to the + # first tag that has its class in `SEARCH_END_TAG_ATTRS` + start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) + if start_tag is None: + return [], "" + + end_tag = start_tag.find_next(self._match_end_tag) + if end_tag is None: + return [], "" + + description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) + description_end_index = search_html.find(str(end_tag)) + description = search_html[description_start_index:description_end_index] + signatures = None - signature = ''.join(signature_buffer) - description = str(symbol_heading.next_sibling.next_sibling).replace('¶', '') + else: + signatures = [] + description = str(symbol_heading.find_next_sibling("dd")) + description_pos = search_html.find(description) + # Get text of up to 3 signatures, remove unwanted symbols + for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): + signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + if signature and search_html.find(str(element)) < description_pos: + signatures.append(signature) - return signature, description + return signatures, description.replace('¶', '') @async_cache(arg_offset=1) async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: @@ -234,7 +298,7 @@ class Doc(commands.Cog): if scraped_html is None: return None - signature = scraped_html[0] + signatures = scraped_html[0] permalink = self.inventories[symbol] description = markdownify(scraped_html[1]) @@ -242,26 +306,42 @@ class Doc(commands.Cog): # of a double newline (interpreted as a paragraph) before index 1000. if len(description) > 1000: shortened = description[:1000] - last_paragraph_end = shortened.rfind('\n\n') - description = description[:last_paragraph_end] + f"... [read more]({permalink})" + last_paragraph_end = shortened.rfind('\n\n', 100) + if last_paragraph_end == -1: + last_paragraph_end = shortened.rfind('. ') + description = description[:last_paragraph_end] + + # If there is an incomplete code block, cut it out + if description.count("```") % 2: + codeblock_start = description.rfind('```py') + description = description[:codeblock_start].rstrip() + description += f"... [read more]({permalink})" description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) - if not signature: + if signatures is None: + # If symbol is a module, don't show signature. + embed_description = description + + elif not signatures: # It's some "meta-page", for example: # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views - return discord.Embed( - title=f'`{symbol}`', - url=permalink, - description="This appears to be a generic page not tied to a specific symbol." - ) + embed_description = "This appears to be a generic page not tied to a specific symbol." - signature = textwrap.shorten(signature, 500) - return discord.Embed( + else: + embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) + embed_description += f"\n{description}" + + embed = discord.Embed( title=f'`{symbol}`', url=permalink, - description=f"```py\n{signature}```{description}" + description=embed_description ) + # Show all symbols with the same name that were renamed in the footer. + embed.set_footer( + text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) + ) + return embed @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: @@ -307,7 +387,10 @@ class Doc(commands.Cog): description=f"Sorry, I could not find any documentation for `{symbol}`.", colour=discord.Colour.red() ) - await ctx.send(embed=error_embed) + error_message = await ctx.send(embed=error_embed) + with suppress(NotFound): + await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) + await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) else: await ctx.send(embed=doc_embed) @@ -365,8 +448,65 @@ class Doc(commands.Cog): await self.refresh_inventory() await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") + @docs_group.command(name="refresh", aliases=("rfsh", "r")) + @with_role(*MODERATION_ROLES) + async def refresh_command(self, ctx: commands.Context) -> None: + """Refresh inventories and send differences to channel.""" + old_inventories = set(self.base_urls) + with ctx.typing(): + await self.refresh_inventory() + # Get differences of added and removed inventories + added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) + if added: + added = f"+ {added}" + + removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) + if removed: + removed = f"- {removed}" + + embed = discord.Embed( + title="Inventories refreshed", + description=f"```diff\n{added}\n{removed}```" if added or removed else "" + ) + await ctx.send(embed=embed) + + async def _fetch_inventory(self, inventory_url: str, config: SphinxConfiguration) -> Optional[dict]: + """Get and return inventory from `inventory_url`. If fetching fails, return None.""" + fetch_func = functools.partial(intersphinx.fetch_inventory, config, '', inventory_url) + for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): + try: + package = await self.bot.loop.run_in_executor(None, fetch_func) + except ConnectTimeout: + log.error( + f"Fetching of inventory {inventory_url} timed out," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except ProtocolError: + log.error( + f"Connection lost while fetching inventory {inventory_url}," + f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" + ) + except HTTPError as e: + log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") + return None + except ConnectionError: + log.error(f"Couldn't establish connection to inventory {inventory_url}.") + return None + else: + return package + log.error(f"Fetching of inventory {inventory_url} failed.") + return None + + @staticmethod + def _match_end_tag(tag: Tag) -> bool: + """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" + for attr in SEARCH_END_TAG_ATTRS: + if attr in tag.get("class", ()): + return True + + return tag.name == "table" + -def setup(bot: commands.Bot) -> None: - """Doc cog load.""" +def setup(bot: Bot) -> None: + """Load the Doc cog.""" bot.add_cog(Doc(bot)) - log.info("Cog loaded: Doc") diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py new file mode 100644 index 000000000..345d2856c --- /dev/null +++ b/bot/cogs/duck_pond.py @@ -0,0 +1,182 @@ +import logging +from typing import Optional, Union + +import discord +from discord import Color, Embed, Member, Message, RawReactionActionEvent, User, errors +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot +from bot.utils.messages import send_attachments + +log = logging.getLogger(__name__) + + +class DuckPond(Cog): + """Relays messages to #duck-pond whenever a certain number of duck reactions have been achieved.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.webhook_id = constants.Webhooks.duck_pond + self.bot.loop.create_task(self.fetch_webhook()) + + async def fetch_webhook(self) -> None: + """Fetches the webhook object, so we can post to it.""" + await self.bot.wait_until_ready() + + try: + self.webhook = await self.bot.fetch_webhook(self.webhook_id) + except discord.HTTPException: + log.exception(f"Failed to fetch webhook with id `{self.webhook_id}`") + + @staticmethod + def is_staff(member: Union[User, Member]) -> bool: + """Check if a specific member or user is staff.""" + if hasattr(member, "roles"): + for role in member.roles: + if role.id in constants.STAFF_ROLES: + return True + return False + + async def has_green_checkmark(self, message: Message) -> bool: + """Check if the message has a green checkmark reaction.""" + for reaction in message.reactions: + if reaction.emoji == "✅": + async for user in reaction.users(): + if user == self.bot.user: + return True + return False + + async def send_webhook( + self, + content: Optional[str] = None, + username: Optional[str] = None, + avatar_url: Optional[str] = None, + embed: Optional[Embed] = None, + ) -> None: + """Send a webhook to the duck_pond channel.""" + try: + await self.webhook.send( + content=content, + username=username, + avatar_url=avatar_url, + embed=embed + ) + except discord.HTTPException: + log.exception("Failed to send a message to the Duck Pool webhook") + + async def count_ducks(self, message: Message) -> int: + """ + Count the number of ducks in the reactions of a specific message. + + Only counts ducks added by staff members. + """ + duck_count = 0 + duck_reactors = [] + + for reaction in message.reactions: + async for user in reaction.users(): + + # Is the user a staff member and not already counted as reactor? + if not self.is_staff(user) or user.id in duck_reactors: + continue + + # Is the emoji a duck? + if hasattr(reaction.emoji, "id"): + if reaction.emoji.id in constants.DuckPond.custom_emojis: + duck_count += 1 + duck_reactors.append(user.id) + elif isinstance(reaction.emoji, str): + if reaction.emoji == "🦆": + duck_count += 1 + duck_reactors.append(user.id) + return duck_count + + async def relay_message(self, message: Message) -> None: + """Relays the message's content and attachments to the duck pond channel.""" + clean_content = message.clean_content + + if clean_content: + await self.send_webhook( + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + if message.attachments: + try: + await send_attachments(message, self.webhook) + except (errors.Forbidden, errors.NotFound): + e = Embed( + description=":x: **This message contained an attachment, but it could not be retrieved**", + color=Color.red() + ) + await self.send_webhook( + embed=e, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + except discord.HTTPException: + log.exception(f"Failed to send an attachment to the webhook") + + await message.add_reaction("✅") + + @staticmethod + def _payload_has_duckpond_emoji(payload: RawReactionActionEvent) -> bool: + """Test if the RawReactionActionEvent payload contains a duckpond emoji.""" + if payload.emoji.is_custom_emoji(): + if payload.emoji.id in constants.DuckPond.custom_emojis: + return True + elif payload.emoji.name == "🦆": + return True + + return False + + @Cog.listener() + async def on_raw_reaction_add(self, payload: RawReactionActionEvent) -> None: + """ + Determine if a message should be sent to the duck pond. + + This will count the number of duck reactions on the message, and if this amount meets the + amount of ducks specified in the config under duck_pond/threshold, it will + send the message off to the duck pond. + """ + # Is the emoji in the reaction a duck? + if not self._payload_has_duckpond_emoji(payload): + return + + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + message = await channel.fetch_message(payload.message_id) + member = discord.utils.get(message.guild.members, id=payload.user_id) + + # Is the member a human and a staff member? + if not self.is_staff(member) or member.bot: + return + + # Does the message already have a green checkmark? + if await self.has_green_checkmark(message): + return + + # Time to count our ducks! + duck_count = await self.count_ducks(message) + + # If we've got more than the required amount of ducks, send the message to the duck_pond. + if duck_count >= constants.DuckPond.threshold: + await self.relay_message(message) + + @Cog.listener() + async def on_raw_reaction_remove(self, payload: RawReactionActionEvent) -> None: + """Ensure that people don't remove the green checkmark from duck ponded messages.""" + channel = discord.utils.get(self.bot.get_all_channels(), id=payload.channel_id) + + # Prevent the green checkmark from being removed + if payload.emoji.name == "✅": + message = await channel.fetch_message(payload.message_id) + duck_count = await self.count_ducks(message) + if duck_count >= constants.DuckPond.threshold: + await message.add_reaction("✅") + + +def setup(bot: Bot) -> None: + """Load the DuckPond cog.""" + bot.add_cog(DuckPond(bot)) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 49411814c..52893b2ee 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -14,9 +14,10 @@ from discord.ext.commands import ( NoPrivateMessage, UserInputError, ) -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels from bot.decorators import InChannelCheckFailure @@ -75,6 +76,16 @@ class ErrorHandler(Cog): tags_get_command = self.bot.get_command("tags get") ctx.invoked_from_error_handler = True + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) + return + # Return to not raise the exception with contextlib.suppress(ResponseCodeError): await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) @@ -143,6 +154,5 @@ class ErrorHandler(Cog): def setup(bot: Bot) -> None: - """Error handler cog load.""" + """Load the ErrorHandler cog.""" bot.add_cog(ErrorHandler(bot)) - log.info("Cog loaded: Events") diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 9ce854f2c..9c729f28a 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -9,8 +9,9 @@ from io import StringIO from typing import Any, Optional, Tuple import discord -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Roles from bot.decorators import with_role from bot.interpreter import Interpreter @@ -148,7 +149,7 @@ class CodeEval(Cog): self.env.update(env) # Ignore this code, it works - _code = """ + code_ = """ async def func(): # (None,) -> Any try: with contextlib.redirect_stdout(self.stdout): @@ -162,7 +163,7 @@ async def func(): # (None,) -> Any """.format(textwrap.indent(code, ' ')) try: - exec(_code, self.env) # noqa: B102,S102 + exec(code_, self.env) # noqa: B102,S102 func = self.env['func'] res = await func() @@ -197,6 +198,5 @@ async def func(): # (None,) -> Any def setup(bot: Bot) -> None: - """Code eval cog load.""" + """Load the CodeEval cog.""" bot.add_cog(CodeEval(bot)) - log.info("Cog loaded: Eval") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index bb66e0b8e..f16e79fb7 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -6,8 +6,9 @@ from pkgutil import iter_modules from discord import Colour, Embed from discord.ext import commands -from discord.ext.commands import Bot, Context, group +from discord.ext.commands import Context, group +from bot.bot import Bot from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs from bot.pagination import LinePaginator from bot.utils.checks import with_role_check @@ -233,4 +234,3 @@ class Extensions(commands.Cog): def setup(bot: Bot) -> None: """Load the Extensions cog.""" bot.add_cog(Extensions(bot)) - log.info("Cog loaded: Extensions") diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 4195783f1..74538542a 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -5,8 +5,9 @@ from typing import Optional, Union import discord.errors from dateutil.relativedelta import relativedelta from discord import Colour, DMChannel, Member, Message, TextChannel -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, Colours, @@ -43,7 +44,7 @@ class Filtering(Cog): def __init__(self, bot: Bot): self.bot = bot - _staff_mistake_str = "If you believe this was a mistake, please let staff know!" + staff_mistake_str = "If you believe this was a mistake, please let staff know!" self.filters = { "filter_zalgo": { "enabled": Filter.filter_zalgo, @@ -53,7 +54,7 @@ class Filtering(Cog): "user_notification": Filter.notify_user_zalgo, "notification_msg": ( "Your post has been removed for abusing Unicode character rendering (aka Zalgo text). " - f"{_staff_mistake_str}" + f"{staff_mistake_str}" ) }, "filter_invites": { @@ -63,7 +64,7 @@ class Filtering(Cog): "content_only": True, "user_notification": Filter.notify_user_invites, "notification_msg": ( - f"Per Rule 6, your invite link has been removed. {_staff_mistake_str}\n\n" + f"Per Rule 6, your invite link has been removed. {staff_mistake_str}\n\n" r"Our server rules can be found here: <https://pythondiscord.com/pages/rules>" ) }, @@ -74,7 +75,7 @@ class Filtering(Cog): "content_only": True, "user_notification": Filter.notify_user_domains, "notification_msg": ( - f"Your URL has been removed because it matched a blacklisted domain. {_staff_mistake_str}" + f"Your URL has been removed because it matched a blacklisted domain. {staff_mistake_str}" ) }, "watch_rich_embeds": { @@ -370,6 +371,5 @@ class Filtering(Cog): def setup(bot: Bot) -> None: - """Filtering cog load.""" + """Load the Filtering cog.""" bot.add_cog(Filtering(bot)) - log.info("Cog loaded: Filtering") diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 82285656b..49cab6172 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -3,8 +3,9 @@ from datetime import datetime from operator import itemgetter from discord import Colour, Embed, Member, utils -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.constants import Categories, Channels, Free, STAFF_ROLES from bot.decorators import redirect_output @@ -98,6 +99,5 @@ class Free(Cog): def setup(bot: Bot) -> None: - """Free cog load.""" + """Load the Free cog.""" bot.add_cog(Free()) - log.info("Cog loaded: Free") diff --git a/bot/cogs/help.py b/bot/cogs/help.py index 9607dbd8d..6385fa467 100644 --- a/bot/cogs/help.py +++ b/bot/cogs/help.py @@ -6,10 +6,11 @@ from typing import Union from discord import Colour, Embed, HTTPException, Message, Reaction, User from discord.ext import commands -from discord.ext.commands import Bot, CheckFailure, Cog as DiscordCog, Command, Context +from discord.ext.commands import CheckFailure, Cog as DiscordCog, Command, Context from fuzzywuzzy import fuzz, process from bot import constants +from bot.bot import Bot from bot.constants import Channels, STAFF_ROLES from bot.decorators import redirect_output from bot.pagination import ( diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 530453600..1ede95ff4 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -9,10 +9,11 @@ from typing import Any, Mapping, Optional import discord from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, command, group +from discord.ext.commands import BucketType, Cog, Context, command, group from discord.utils import escape_markdown from bot import constants +from bot.bot import Bot from bot.decorators import InChannelCheckFailure, in_channel, with_role from bot.utils.checks import cooldown_with_role_bypass, with_role_check from bot.utils.time import time_since @@ -391,6 +392,5 @@ class Information(Cog): def setup(bot: Bot) -> None: - """Information cog load.""" + """Load the Information cog.""" bot.add_cog(Information(bot)) - log.info("Cog loaded: Information") diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index be9d33e3e..985f28ce5 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -4,6 +4,7 @@ from discord import Member, PermissionOverwrite, utils from discord.ext import commands from more_itertools import unique_everseen +from bot.bot import Bot from bot.constants import Roles from bot.decorators import with_role @@ -13,7 +14,7 @@ log = logging.getLogger(__name__) class CodeJams(commands.Cog): """Manages the code-jam related parts of our server.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @commands.command() @@ -108,7 +109,6 @@ class CodeJams(commands.Cog): ) -def setup(bot: commands.Bot) -> None: - """Code Jams cog load.""" +def setup(bot: Bot) -> None: + """Load the CodeJams cog.""" bot.add_cog(CodeJams(bot)) - log.info("Cog loaded: CodeJams") diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index c92b619ff..d1b7dcab3 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -1,8 +1,9 @@ import logging from discord import Embed -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog +from bot.bot import Bot from bot.constants import Channels, DEBUG_MODE @@ -37,6 +38,5 @@ class Logging(Cog): def setup(bot: Bot) -> None: - """Logging cog load.""" + """Load the Logging cog.""" bot.add_cog(Logging(bot)) - log.info("Cog loaded: Logging") diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py index 7383ed44e..5243cb92d 100644 --- a/bot/cogs/moderation/__init__.py +++ b/bot/cogs/moderation/__init__.py @@ -1,25 +1,13 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot from .infractions import Infractions from .management import ModManagement from .modlog import ModLog from .superstarify import Superstarify -log = logging.getLogger(__name__) - def setup(bot: Bot) -> None: - """Load the moderation extension (Infractions, ModManagement, ModLog, & Superstarify cogs).""" + """Load the Infractions, ModManagement, ModLog, and Superstarify cogs.""" bot.add_cog(Infractions(bot)) - log.info("Cog loaded: Infractions") - bot.add_cog(ModLog(bot)) - log.info("Cog loaded: ModLog") - bot.add_cog(ModManagement(bot)) - log.info("Cog loaded: ModManagement") - bot.add_cog(Superstarify(bot)) - log.info("Cog loaded: Superstarify") diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 2713a1b68..3536a3d38 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -7,6 +7,7 @@ from discord.ext import commands from discord.ext.commands import Context, command from bot import constants +from bot.bot import Bot from bot.constants import Event from bot.decorators import respect_role_hierarchy from bot.utils.checks import with_role_check @@ -25,7 +26,7 @@ class Infractions(InfractionScheduler, commands.Cog): category = "Moderation" category_description = "Server moderation tools." - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): super().__init__(bot, supported_infractions={"ban", "kick", "mute", "note", "warning"}) self.category = "Moderation" @@ -208,8 +209,13 @@ class Infractions(InfractionScheduler, commands.Cog): self.mod_log.ignore(Event.member_update, user.id) - action = user.add_roles(self._muted_role, reason=reason) - await self.apply_infraction(ctx, infraction, user, action) + async def action() -> None: + await user.add_roles(self._muted_role, reason=reason) + + log.trace(f"Attempting to kick {user} from voice because they've been muted.") + await user.move_to(None, reason=reason) + + await self.apply_infraction(ctx, infraction, user, action()) @respect_role_hierarchy() async def apply_kick(self, ctx: Context, user: Member, reason: str, **kwargs) -> None: diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index 44a508436..9605d47b2 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -2,13 +2,15 @@ import asyncio import logging import textwrap import typing as t +from datetime import datetime import discord from discord.ext import commands from discord.ext.commands import Context from bot import constants -from bot.converters import InfractionSearchQuery +from bot.bot import Bot +from bot.converters import InfractionSearchQuery, allowed_strings from bot.pagination import LinePaginator from bot.utils import time from bot.utils.checks import in_channel_check, with_role_check @@ -21,21 +23,12 @@ log = logging.getLogger(__name__) UserConverter = t.Union[discord.User, utils.proxy_user] -def permanent_duration(expires_at: str) -> str: - """Only allow an expiration to be 'permanent' if it is a string.""" - expires_at = expires_at.lower() - if expires_at != "permanent": - raise commands.BadArgument - else: - return expires_at - - class ModManagement(commands.Cog): """Management of infractions.""" category = "Moderation" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @property @@ -59,8 +52,8 @@ class ModManagement(commands.Cog): async def infraction_edit( self, ctx: Context, - infraction_id: int, - duration: t.Union[utils.Expiry, permanent_duration, None], + infraction_id: t.Union[int, allowed_strings("l", "last", "recent")], + duration: t.Union[utils.Expiry, allowed_strings("p", "permanent"), None], *, reason: str = None ) -> None: @@ -77,26 +70,45 @@ class ModManagement(commands.Cog): \u2003`M` - minutes∗ \u2003`s` - seconds - Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp - can be provided for the duration. + Use "l", "last", or "recent" as the infraction ID to specify that the most recent infraction + authored by the command invoker should be edited. + + Use "p" or "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 + timestamp can be provided for the duration. """ if duration is None and reason is None: # Unlike UserInputError, the error handler will show a specified message for BadArgument raise commands.BadArgument("Neither a new expiry nor a new reason was specified.") # Retrieve the previous infraction for its information. - old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}') + if isinstance(infraction_id, str): + params = { + "actor__id": ctx.author.id, + "ordering": "-inserted_at" + } + infractions = await self.bot.api_client.get(f"bot/infractions", params=params) + + if infractions: + old_infraction = infractions[0] + infraction_id = old_infraction["id"] + else: + await ctx.send( + f":x: Couldn't find most recent infraction; you have never given an infraction." + ) + return + else: + old_infraction = await self.bot.api_client.get(f"bot/infractions/{infraction_id}") request_data = {} confirm_messages = [] log_text = "" - if duration == "permanent": + if isinstance(duration, str): request_data['expires_at'] = None confirm_messages.append("marked as permanent") elif duration is not None: request_data['expires_at'] = duration.isoformat() - expiry = duration.strftime(time.INFRACTION_FORMAT) + expiry = time.format_infraction_with_duration(request_data['expires_at']) confirm_messages.append(f"set to expire on {expiry}") else: confirm_messages.append("expiry unchanged") @@ -128,7 +140,8 @@ class ModManagement(commands.Cog): New expiry: {new_infraction['expires_at'] or "Permanent"} """.rstrip() - await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}") + changes = ' & '.join(confirm_messages) + await ctx.send(f":ok_hand: Updated infraction #{infraction_id}: {changes}") # Get information about the infraction's user user_id = new_infraction['user'] @@ -231,10 +244,17 @@ class ModManagement(commands.Cog): user_id = infraction["user"] hidden = infraction["hidden"] created = time.format_infraction(infraction["inserted_at"]) + + if active: + remaining = time.until_expiration(infraction["expires_at"]) or "Expired" + else: + remaining = "Inactive" + if infraction["expires_at"] is None: expires = "*Permanent*" else: - expires = time.format_infraction(infraction["expires_at"]) + date_from = datetime.strptime(created, time.INFRACTION_FORMAT) + expires = time.format_infraction_with_duration(infraction["expires_at"], date_from) lines = textwrap.dedent(f""" {"**===============**" if active else "==============="} @@ -245,6 +265,7 @@ class ModManagement(commands.Cog): Reason: {infraction["reason"] or "*None*"} Created: {created} Expires: {expires} + Remaining: {remaining} Actor: {actor.mention if actor else actor_id} ID: `{infraction["id"]}` {"**===============**" if active else "==============="} diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index 88f2b6c67..35ef6cbcc 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -1,4 +1,6 @@ import asyncio +import difflib +import itertools import logging import typing as t from datetime import datetime @@ -8,8 +10,9 @@ from dateutil.relativedelta import relativedelta from deepdiff import DeepDiff from discord import Colour from discord.abc import GuildChannel -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context +from bot.bot import Bot from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs from bot.utils.time import humanize_delta from .utils import UserTypes @@ -618,80 +621,81 @@ class ModLog(Cog, name="ModLog"): ) @Cog.listener() - async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: + async def on_message_edit(self, msg_before: discord.Message, msg_after: discord.Message) -> None: """Log message edit event to message change log.""" if ( - not before.guild - or before.guild.id != GuildConstant.id - or before.channel.id in GuildConstant.ignored - or before.author.bot + not msg_before.guild + or msg_before.guild.id != GuildConstant.id + or msg_before.channel.id in GuildConstant.ignored + or msg_before.author.bot ): return - self._cached_edits.append(before.id) + self._cached_edits.append(msg_before.id) - if before.content == after.content: + if msg_before.content == msg_after.content: return - author = before.author - channel = before.channel + author = msg_before.author + channel = msg_before.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" - if channel.category: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{before.clean_content}" - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{after.clean_content}" - ) - else: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{before.clean_content}" - ) + # Getting the difference per words and group them by type - add, remove, same + # Note that this is intended grouping without sorting + diff = difflib.ndiff(msg_before.clean_content.split(), msg_after.clean_content.split()) + diff_groups = tuple( + (diff_type, tuple(s[2:] for s in diff_words)) + for diff_type, diff_words in itertools.groupby(diff, key=lambda s: s[0]) + ) - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{before.id}`\n" - "\n" - f"{after.clean_content}" - ) + content_before: t.List[str] = [] + content_after: t.List[str] = [] + + for index, (diff_type, words) in enumerate(diff_groups): + sub = ' '.join(words) + if diff_type == '-': + content_before.append(f"[{sub}](http://o.hi)") + elif diff_type == '+': + content_after.append(f"[{sub}](http://o.hi)") + elif diff_type == ' ': + if len(words) > 2: + sub = ( + f"{words[0] if index > 0 else ''}" + " ... " + f"{words[-1] if index < len(diff_groups) - 1 else ''}" + ) + content_before.append(sub) + content_after.append(sub) + + response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{msg_before.id}`\n" + "\n" + f"**Before**:\n{' '.join(content_before)}\n" + f"**After**:\n{' '.join(content_after)}\n" + "\n" + f"[Jump to message]({msg_after.jump_url})" + ) - if before.edited_at: + if msg_before.edited_at: # Message was previously edited, to assist with self-bot detection, use the edited_at # datetime as the baseline and create a human-readable delta between this edit event # and the last time the message was edited - timestamp = before.edited_at - delta = humanize_delta(relativedelta(after.edited_at, before.edited_at)) + timestamp = msg_before.edited_at + delta = humanize_delta(relativedelta(msg_after.edited_at, msg_before.edited_at)) footer = f"Last edited {delta} ago" else: # Message was not previously edited, use the created_at datetime as the baseline, no # delta calculation needed - timestamp = before.created_at + timestamp = msg_before.created_at footer = None await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (Before)", before_response, + Icons.message_edit, Colour.blurple(), "Message edited", response, channel_id=Channels.message_log, timestamp_override=timestamp, footer=footer ) - await self.send_log_message( - Icons.message_edit, Colour.blurple(), "Message edited (After)", after_response, - channel_id=Channels.message_log, timestamp_override=after.edited_at - ) - @Cog.listener() async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None: """Log raw message edit event to message change log.""" @@ -718,39 +722,23 @@ class ModLog(Cog, name="ModLog"): author = message.author channel = message.channel + channel_name = f"{channel.category}/#{channel.name}" if channel.category else f"#{channel.name}" + + before_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + "This message was not cached, so the message content cannot be displayed." + ) - if channel.category: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - f"{message.clean_content}" - ) - else: - before_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - "This message was not cached, so the message content cannot be displayed." - ) - - after_response = ( - f"**Author:** {author} (`{author.id}`)\n" - f"**Channel:** #{channel.name} (`{channel.id}`)\n" - f"**Message ID:** `{message.id}`\n" - "\n" - f"{message.clean_content}" - ) + after_response = ( + f"**Author:** {author} (`{author.id}`)\n" + f"**Channel:** {channel_name} (`{channel.id}`)\n" + f"**Message ID:** `{message.id}`\n" + "\n" + f"{message.clean_content}" + ) await self.send_log_message( Icons.message_edit, Colour.blurple(), "Message edited (Before)", diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index 7990df226..01e4b1fe7 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -7,10 +7,11 @@ from gettext import ngettext import dateutil.parser import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context from bot import constants from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Colours, STAFF_CHANNELS from bot.utils import time from bot.utils.scheduling import Scheduler @@ -39,6 +40,8 @@ class InfractionScheduler(Scheduler): """Schedule expiration for previous infractions.""" await self.bot.wait_until_ready() + log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") + infractions = await self.bot.api_client.get( 'bot/infractions', params={'active': 'true'} @@ -59,6 +62,10 @@ class InfractionScheduler(Scheduler): # Mark as inactive if less than a minute remains. if delta < 60: + log.info( + "Infraction will be deactivated instead of re-applied " + "because less than 1 minute remains." + ) await self.deactivate_infraction(infraction) return @@ -77,10 +84,10 @@ class InfractionScheduler(Scheduler): infr_type = infraction["type"] icon = utils.INFRACTION_ICONS[infr_type][0] reason = infraction["reason"] - expiry = infraction["expires_at"] + expiry = time.format_infraction_with_duration(infraction["expires_at"]) + id_ = infraction['id'] - if expiry: - expiry = time.format_infraction(expiry) + log.trace(f"Applying {infr_type} infraction #{id_} to {user}.") # Default values for the confirmation message and mod log. confirm_msg = f":ok_hand: applied" @@ -107,14 +114,24 @@ class InfractionScheduler(Scheduler): dm_result = ":incoming_envelope: " dm_log_text = "\nDM: Sent" else: + dm_result = f"{constants.Emojis.failmail} " dm_log_text = "\nDM: **Failed**" - log_content = ctx.author.mention if infraction["actor"] == self.bot.user.id: + log.trace( + f"Infraction #{id_} actor is bot; including the reason in the confirmation message." + ) + end_msg = f" (reason: {infraction['reason']})" elif ctx.channel.id not in STAFF_CHANNELS: + log.trace( + f"Infraction #{id_} context is not in a staff channel; omitting infraction count." + ) + end_msg = "" else: + log.trace(f"Fetching total infraction count for {user}.") + infractions = await self.bot.api_client.get( "bot/infractions", params={"user__id": str(user.id)} @@ -124,24 +141,33 @@ class InfractionScheduler(Scheduler): # Execute the necessary actions to apply the infraction on Discord. if action_coro: + log.trace(f"Awaiting the infraction #{id_} application action coroutine.") try: await action_coro if expiry: # Schedule the expiration of the infraction. self.schedule_task(ctx.bot.loop, infraction["id"], infraction) - except discord.Forbidden: + except discord.HTTPException as e: # Accordingly display that applying the infraction failed. confirm_msg = f":x: failed to apply" expiry_msg = "" log_content = ctx.author.mention log_title = "failed to apply" + log_msg = f"Failed to apply {infr_type} infraction #{id_} to {user}" + if isinstance(e, discord.Forbidden): + log.warning(f"{log_msg}: bot lacks permissions.") + else: + log.exception(log_msg) + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} confirmation message.") await ctx.send( f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}." ) # Send a log message to the mod log. + log.trace(f"Sending apply mod log for infraction #{id_}.") await self.mod_log.send_log_message( icon_url=icon, colour=Colours.soft_red, @@ -157,9 +183,14 @@ class InfractionScheduler(Scheduler): footer=f"ID {infraction['id']}" ) + log.info(f"Applied {infr_type} infraction #{id_} to {user}.") + async def pardon_infraction(self, ctx: Context, infr_type: str, user: MemberObject) -> None: """Prematurely end an infraction for a user and log the action in the mod log.""" + log.trace(f"Pardoning {infr_type} infraction for {user}.") + # Check the current active infraction + log.trace(f"Fetching active {infr_type} infractions for {user}.") response = await self.bot.api_client.get( 'bot/infractions', params={ @@ -170,6 +201,7 @@ class InfractionScheduler(Scheduler): ) if not response: + log.debug(f"No active {infr_type} infraction found for {user}.") await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.") return @@ -179,12 +211,16 @@ class InfractionScheduler(Scheduler): log_text["Member"] = f"{user.mention}(`{user.id}`)" log_text["Actor"] = str(ctx.message.author) log_content = None - footer = f"ID: {response[0]['id']}" + id_ = response[0]['id'] + footer = f"ID: {id_}" # If multiple active infractions were found, mark them as inactive in the database # and cancel their expiration tasks. if len(response) > 1: - log.warning(f"Found more than one active {infr_type} infraction for user {user.id}") + log.warning( + f"Found more than one active {infr_type} infraction for user {user.id}; " + "deactivating the extra active infractions too." + ) footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}" @@ -198,15 +234,15 @@ class InfractionScheduler(Scheduler): # 1. Discord cannot store multiple active bans or assign multiples of the same role # 2. It would send a pardon DM for each active infraction, which is redundant for infraction in response[1:]: - _id = infraction['id'] + id_ = infraction['id'] try: # Mark infraction as inactive in the database. await self.bot.api_client.patch( - f"bot/infractions/{_id}", + f"bot/infractions/{id_}", json={"active": False} ) except ResponseCodeError: - log.exception(f"Failed to deactivate infraction #{_id} ({infr_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({infr_type})") # This is simpler and cleaner than trying to concatenate all the errors. log_text["Failure"] = "See bot's logs for details." @@ -219,19 +255,23 @@ class InfractionScheduler(Scheduler): if log_text.get("DM") == "Sent": dm_emoji = ":incoming_envelope: " elif "DM" in log_text: - # Mention the actor because the DM failed to send. - log_content = ctx.author.mention + dm_emoji = f"{constants.Emojis.failmail} " # Accordingly display whether the pardon failed. if "Failure" in log_text: confirm_msg = ":x: failed to pardon" log_title = "pardon failed" log_content = ctx.author.mention + + log.warning(f"Failed to pardon {infr_type} infraction #{id_} for {user}.") else: confirm_msg = f":ok_hand: pardoned" log_title = "pardoned" + log.info(f"Pardoned {infr_type} infraction #{id_} for {user}.") + # Send a confirmation message to the invoking context. + log.trace(f"Sending infraction #{id_} pardon confirmation message.") await ctx.send( f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. " f"{log_text.get('Failure', '')}" @@ -265,10 +305,10 @@ class InfractionScheduler(Scheduler): guild = self.bot.get_guild(constants.Guild.id) mod_role = guild.get_role(constants.Roles.moderator) user_id = infraction["user"] - _type = infraction["type"] - _id = infraction["id"] + type_ = infraction["type"] + id_ = infraction["id"] - log.debug(f"Marking infraction #{_id} as inactive (expired).") + log.info(f"Marking infraction #{id_} as inactive (expired).") log_content = None log_text = { @@ -278,24 +318,28 @@ class InfractionScheduler(Scheduler): } try: + log.trace("Awaiting the pardon action coroutine.") returned_log = await self._pardon_action(infraction) + if returned_log is not None: log_text = {**log_text, **returned_log} # Merge the logs together else: raise ValueError( - f"Attempted to deactivate an unsupported infraction #{_id} ({_type})!" + f"Attempted to deactivate an unsupported infraction #{id_} ({type_})!" ) except discord.Forbidden: - log.warning(f"Failed to deactivate infraction #{_id} ({_type}): bot lacks permissions") + log.warning(f"Failed to deactivate infraction #{id_} ({type_}): bot lacks permissions.") log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)" log_content = mod_role.mention except discord.HTTPException as e: - log.exception(f"Failed to deactivate infraction #{_id} ({_type})") - log_text["Failure"] = f"HTTPException with code {e.code}." + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") + log_text["Failure"] = f"HTTPException with status {e.status} and code {e.code}." log_content = mod_role.mention # Check if the user is currently being watched by Big Brother. try: + log.trace(f"Determining if user {user_id} is currently being watched by Big Brother.") + active_watch = await self.bot.api_client.get( "bot/infractions", params={ @@ -312,12 +356,13 @@ class InfractionScheduler(Scheduler): try: # Mark infraction as inactive in the database. + log.trace(f"Marking infraction #{id_} as inactive in the database.") await self.bot.api_client.patch( - f"bot/infractions/{_id}", + f"bot/infractions/{id_}", json={"active": False} ) except ResponseCodeError as e: - log.exception(f"Failed to deactivate infraction #{_id} ({_type})") + log.exception(f"Failed to deactivate infraction #{id_} ({type_})") log_line = f"API request failed with code {e.status}." log_content = mod_role.mention @@ -335,12 +380,13 @@ class InfractionScheduler(Scheduler): if send_log: log_title = f"expiration failed" if "Failure" in log_text else "expired" + log.trace(f"Sending deactivation mod log for infraction #{id_}.") await self.mod_log.send_log_message( - icon_url=utils.INFRACTION_ICONS[_type][1], + icon_url=utils.INFRACTION_ICONS[type_][1], colour=Colours.soft_green, - title=f"Infraction {log_title}: {_type}", + title=f"Infraction {log_title}: {type_}", text="\n".join(f"{k}: {v}" for k, v in log_text.items()), - footer=f"ID: {_id}", + footer=f"ID: {id_}", content=log_content, ) diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index c66222e5a..7631d9bbe 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -6,9 +6,10 @@ import typing as t from pathlib import Path from discord import Colour, Embed, Member -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command from bot import constants +from bot.bot import Bot from bot.utils.checks import with_role_check from bot.utils.time import format_infraction from . import utils @@ -34,8 +35,8 @@ class Superstarify(InfractionScheduler, Cog): return # User didn't change their nickname. Abort! log.trace( - f"{before.display_name} is trying to change their nickname to {after.display_name}. " - "Checking if the user is in superstar-prison..." + f"{before} ({before.display_name}) is trying to change their nickname to " + f"{after.display_name}. Checking if the user is in superstar-prison..." ) active_superstarifies = await self.bot.api_client.get( @@ -48,6 +49,7 @@ class Superstarify(InfractionScheduler, Cog): ) if not active_superstarifies: + log.trace(f"{before} has no active superstar infractions.") return infraction = active_superstarifies[0] @@ -132,15 +134,17 @@ class Superstarify(InfractionScheduler, Cog): # Post the infraction to the API reason = reason or f"old nick: {member.display_name}" infraction = await utils.post_infraction(ctx, member, "superstar", reason, duration) + id_ = infraction["id"] old_nick = member.display_name - forced_nick = self.get_nick(infraction["id"], member.id) + forced_nick = self.get_nick(id_, member.id) expiry_str = format_infraction(infraction["expires_at"]) # Apply the infraction and schedule the expiration task. + log.debug(f"Changing nickname of {member} to {forced_nick}.") self.mod_log.ignore(constants.Event.member_update, member.id) await member.edit(nick=forced_nick, reason=reason) - self.schedule_task(ctx.bot.loop, infraction["id"], infraction) + self.schedule_task(ctx.bot.loop, id_, infraction) # Send a DM to the user to notify them of their new infraction. await utils.notify_infraction( @@ -152,6 +156,7 @@ class Superstarify(InfractionScheduler, Cog): ) # Send an embed with the infraction information to the invoking context. + log.trace(f"Sending superstar #{id_} embed.") embed = Embed( title="Congratulations!", colour=constants.Colours.soft_orange, @@ -167,6 +172,7 @@ class Superstarify(InfractionScheduler, Cog): await ctx.send(embed=embed) # Log to the mod log channel. + log.trace(f"Sending apply mod log for superstar #{id_}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS["superstar"][0], colour=Colour.gold(), @@ -180,7 +186,7 @@ class Superstarify(InfractionScheduler, Cog): Old nickname: `{old_nick}` New nickname: `{forced_nick}` """), - footer=f"ID {infraction['id']}" + footer=f"ID {id_}" ) @command(name="unsuperstarify", aliases=("release_nick", "unstar")) @@ -198,6 +204,10 @@ class Superstarify(InfractionScheduler, Cog): # Don't bother sending a notification if the user left the guild. if not user: + log.debug( + "User left the guild and therefore won't be notified about superstar " + f"{infraction['id']} pardon." + ) return {} # DM the user about the expiration. @@ -216,6 +226,8 @@ class Superstarify(InfractionScheduler, Cog): @staticmethod def get_nick(infraction_id: int, member_id: int) -> str: """Randomly select a nickname from the Superstarify nickname list.""" + log.trace(f"Choosing a random nickname for superstar #{infraction_id}.") + rng = random.Random(str(infraction_id) + str(member_id)) return rng.choice(STAR_NAMES) diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py index 9179c0afb..325b9567a 100644 --- a/bot/cogs/moderation/utils.py +++ b/bot/cogs/moderation/utils.py @@ -37,6 +37,8 @@ def proxy_user(user_id: str) -> discord.Object: Used when a Member or User object cannot be resolved. """ + log.trace(f"Attempting to create a proxy user for the user id {user_id}.") + try: user_id = int(user_id) except ValueError: @@ -59,6 +61,8 @@ async def post_infraction( active: bool = True, ) -> t.Optional[dict]: """Posts an infraction to the API.""" + log.trace(f"Posting {infr_type} infraction for {user} to the API.") + payload = { "actor": ctx.message.author.id, "hidden": hidden, @@ -92,6 +96,8 @@ async def post_infraction( async def has_active_infraction(ctx: Context, user: MemberObject, infr_type: str) -> bool: """Checks if a user already has an active infraction of the given type.""" + log.trace(f"Checking if {user} has active infractions of type {infr_type}.") + active_infractions = await ctx.bot.api_client.get( 'bot/infractions', params={ @@ -101,12 +107,14 @@ async def has_active_infraction(ctx: Context, user: MemberObject, infr_type: str } ) if active_infractions: + log.trace(f"{user} has active infractions of type {infr_type}.") await ctx.send( f":x: According to my records, this user already has a {infr_type} infraction. " f"See infraction **#{active_infractions[0]['id']}**." ) return True else: + log.trace(f"{user} does not have active infractions of type {infr_type}.") return False @@ -118,6 +126,8 @@ async def notify_infraction( icon_url: str = Icons.token_removed ) -> bool: """DM a user about their new infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their {infr_type} infraction.") + embed = discord.Embed( description=textwrap.dedent(f""" **Type:** {infr_type.capitalize()} @@ -146,6 +156,8 @@ async def notify_pardon( icon_url: str = Icons.user_verified ) -> bool: """DM a user about their pardoned infraction and return True if the DM is successful.""" + log.trace(f"Sending {user} a DM about their pardoned infraction.") + embed = discord.Embed( description=content, colour=Colours.soft_green diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index 78792240f..bf777ea5a 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -4,9 +4,10 @@ import logging from datetime import datetime, timedelta from discord import Colour, Embed -from discord.ext.commands import BadArgument, Bot, Cog, Context, Converter, group +from discord.ext.commands import BadArgument, Cog, Context, Converter, group from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES from bot.decorators import with_role from bot.pagination import LinePaginator @@ -184,6 +185,5 @@ class OffTopicNames(Cog): def setup(bot: Bot) -> None: - """Off topic names cog load.""" + """Load the OffTopicNames cog.""" bot.add_cog(OffTopicNames(bot)) - log.info("Cog loaded: OffTopicNames") diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index 0d06e9c26..aa487f18e 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -2,13 +2,16 @@ import asyncio import logging import random import textwrap +from collections import namedtuple from datetime import datetime, timedelta from typing import List +from aiohttp import BasicAuth, ClientError from discord import Colour, Embed, TextChannel -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group from discord.ext.tasks import loop +from bot.bot import Bot from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks from bot.converters import Subreddit from bot.decorators import with_role @@ -16,25 +19,32 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_at"]) + class Reddit(Cog): """Track subreddit posts and show detailed statistics about them.""" - HEADERS = {"User-Agent": "Discord Bot: PythonDiscord (https://pythondiscord.com/)"} + HEADERS = {"User-Agent": "python3:python-discord/bot:1.0.0 (by /u/PythonDiscord)"} URL = "https://www.reddit.com" - MAX_FETCH_RETRIES = 3 + OAUTH_URL = "https://oauth.reddit.com" + MAX_RETRIES = 3 def __init__(self, bot: Bot): self.bot = bot - self.webhook = None # set in on_ready - bot.loop.create_task(self.init_reddit_ready()) + self.webhook = None + self.access_token = None + self.client_auth = BasicAuth(RedditConfig.client_id, RedditConfig.secret) + bot.loop.create_task(self.init_reddit_ready()) self.auto_poster_loop.start() def cog_unload(self) -> None: - """Stops the loops when the cog is unloaded.""" + """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() + if self.access_token.expires_at < datetime.utcnow(): + self.revoke_access_token() async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" @@ -47,20 +57,82 @@ class Reddit(Cog): """Get the #reddit channel object from the bot's cache.""" return self.bot.get_channel(Channels.reddit) + async def get_access_token(self) -> None: + """ + Get a Reddit API OAuth2 access token and assign it to self.access_token. + + A token is valid for 1 hour. There will be MAX_RETRIES to get a token, after which the cog + will be unloaded and a ClientError raised if retrieval was still unsuccessful. + """ + for i in range(1, self.MAX_RETRIES + 1): + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/access_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "grant_type": "client_credentials", + "duration": "temporary" + } + ) + + if response.status == 200 and response.content_type == "application/json": + content = await response.json() + expiration = int(content["expires_in"]) - 60 # Subtract 1 minute for leeway. + self.access_token = AccessToken( + token=content["access_token"], + expires_at=datetime.utcnow() + timedelta(seconds=expiration) + ) + + log.debug(f"New token acquired; expires on {self.access_token.expires_at}") + return + else: + log.debug( + f"Failed to get an access token: " + f"status {response.status} & content type {response.content_type}; " + f"retrying ({i}/{self.MAX_RETRIES})" + ) + + await asyncio.sleep(3) + + self.bot.remove_cog(self.qualified_name) + raise ClientError("Authentication with the Reddit API failed. Unloading the cog.") + + async def revoke_access_token(self) -> None: + """ + Revoke the OAuth2 access token for the Reddit API. + + For security reasons, it's good practice to revoke the token when it's no longer being used. + """ + response = await self.bot.http_session.post( + url=f"{self.URL}/api/v1/revoke_token", + headers=self.HEADERS, + auth=self.client_auth, + data={ + "token": self.access_token.token, + "token_type_hint": "access_token" + } + ) + + if response.status == 204 and response.content_type == "application/json": + self.access_token = None + else: + log.warning(f"Unable to revoke access token: status {response.status}.") + async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]: """A helper method to fetch a certain amount of Reddit posts at a given route.""" # Reddit's JSON responses only provide 25 posts at most. if not 25 >= amount > 0: raise ValueError("Invalid amount of subreddit posts requested.") - if params is None: - params = {} + # Renew the token if necessary. + if not self.access_token or self.access_token.expires_at < datetime.utcnow(): + await self.get_access_token() - url = f"{self.URL}/{route}.json" - for _ in range(self.MAX_FETCH_RETRIES): + url = f"{self.OAUTH_URL}/{route}" + for _ in range(self.MAX_RETRIES): response = await self.bot.http_session.get( url=url, - headers=self.HEADERS, + headers={**self.HEADERS, "Authorization": f"bearer {self.access_token.token}"}, params=params ) if response.status == 200 and response.content_type == 'application/json': @@ -217,6 +289,5 @@ class Reddit(Cog): def setup(bot: Bot) -> None: - """Reddit cog load.""" + """Load the Reddit cog.""" bot.add_cog(Reddit(bot)) - log.info("Cog loaded: Reddit") diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 81990704b..45bf9a8f4 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -8,8 +8,9 @@ from typing import Optional from dateutil.relativedelta import relativedelta from discord import Colour, Embed, Message -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES from bot.converters import Duration from bot.pagination import LinePaginator @@ -290,6 +291,5 @@ class Reminders(Scheduler, Cog): def setup(bot: Bot) -> None: - """Reminders cog load.""" + """Load the Reminders cog.""" bot.add_cog(Reminders(bot)) - log.info("Cog loaded: Reminders") diff --git a/bot/cogs/security.py b/bot/cogs/security.py index 316b33d6b..c680c5e27 100644 --- a/bot/cogs/security.py +++ b/bot/cogs/security.py @@ -1,6 +1,8 @@ import logging -from discord.ext.commands import Bot, Cog, Context, NoPrivateMessage +from discord.ext.commands import Cog, Context, NoPrivateMessage + +from bot.bot import Bot log = logging.getLogger(__name__) @@ -25,6 +27,5 @@ class Security(Cog): def setup(bot: Bot) -> None: - """Security cog load.""" + """Load the Security cog.""" bot.add_cog(Security(bot)) - log.info("Cog loaded: Security") diff --git a/bot/cogs/site.py b/bot/cogs/site.py index 683613788..2ea8c7a2e 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -1,8 +1,9 @@ import logging from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import URLs from bot.pagination import LinePaginator @@ -138,6 +139,5 @@ class Site(Cog): def setup(bot: Bot) -> None: - """Site cog load.""" + """Load the Site cog.""" bot.add_cog(Site(bot)) - log.info("Cog loaded: Site") diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index 362968bd0..da33e27b2 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -5,8 +5,9 @@ import textwrap from signal import Signals from typing import Optional, Tuple -from discord.ext.commands import Bot, Cog, Context, command, guild_only +from discord.ext.commands import Cog, Context, command, guild_only +from bot.bot import Bot from bot.constants import Channels, Roles, URLs from bot.decorators import in_channel from bot.utils.messages import wait_for_deletion @@ -176,7 +177,7 @@ class Snekbox(Cog): @command(name="eval", aliases=("e",)) @guild_only() - @in_channel(Channels.bot, bypass_roles=EVAL_ROLES) + @in_channel(Channels.bot, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. @@ -227,6 +228,5 @@ class Snekbox(Cog): def setup(bot: Bot) -> None: - """Snekbox cog load.""" + """Load the Snekbox cog.""" bot.add_cog(Snekbox(bot)) - log.info("Cog loaded: Snekbox") diff --git a/bot/cogs/sync/__init__.py b/bot/cogs/sync/__init__.py index d4565f848..fe7df4e9b 100644 --- a/bot/cogs/sync/__init__.py +++ b/bot/cogs/sync/__init__.py @@ -1,13 +1,7 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot from .cog import Sync -log = logging.getLogger(__name__) - def setup(bot: Bot) -> None: - """Sync cog load.""" + """Load the Sync cog.""" bot.add_cog(Sync(bot)) - log.info("Cog loaded: Sync") diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index aaa581f96..90d4c40fe 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -3,10 +3,11 @@ from typing import Callable, Iterable from discord import Guild, Member, Role from discord.ext import commands -from discord.ext.commands import Bot, Cog, Context +from discord.ext.commands import Cog, Context from bot import constants from bot.api import ResponseCodeError +from bot.bot import Bot from bot.cogs.sync import syncers log = logging.getLogger(__name__) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 2cc5a66e1..14cf51383 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -2,7 +2,8 @@ from collections import namedtuple from typing import Dict, Set, Tuple from discord import Guild -from discord.ext.commands import Bot + +from bot.bot import Bot # These objects are declared as namedtuples because tuples are hashable, # something that we make use of when diffing site roles against guild roles. @@ -52,7 +53,7 @@ async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]: Synchronize roles found on the given `guild` with the ones on the API. Arguments: - bot (discord.ext.commands.Bot): + bot (bot.bot.Bot): The bot instance that we're running with. guild (discord.Guild): @@ -169,7 +170,7 @@ async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]: Synchronize users found in the given `guild` with the ones in the API. Arguments: - bot (discord.ext.commands.Bot): + bot (bot.bot.Bot): The bot instance that we're running with. guild (discord.Guild): diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index cd70e783a..970301013 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -2,8 +2,9 @@ import logging import time from discord import Colour, Embed -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.constants import Channels, Cooldowns, MODERATION_ROLES, Roles from bot.converters import TagContentConverter, TagNameConverter from bot.decorators import with_role @@ -160,6 +161,5 @@ class Tags(Cog): def setup(bot: Bot) -> None: - """Tags cog load.""" + """Load the Tags cog.""" bot.add_cog(Tags(bot)) - log.info("Cog loaded: Tags") diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py index 5a0d20e57..82c01ae96 100644 --- a/bot/cogs/token_remover.py +++ b/bot/cogs/token_remover.py @@ -6,9 +6,10 @@ import struct from datetime import datetime from discord import Colour, Message -from discord.ext.commands import Bot, Cog +from discord.ext.commands import Cog from discord.utils import snowflake_time +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import Channels, Colours, Event, Icons @@ -52,39 +53,60 @@ class TokenRemover(Cog): See: https://discordapp.com/developers/docs/reference#snowflakes """ + if self.is_token_in_message(msg): + await self.take_action(msg) + + @Cog.listener() + async def on_message_edit(self, before: Message, after: Message) -> None: + """ + Check each edit for a string that matches Discord's token pattern. + + See: https://discordapp.com/developers/docs/reference#snowflakes + """ + if self.is_token_in_message(after): + await self.take_action(after) + + async def take_action(self, msg: Message) -> None: + """Remove the `msg` containing a token an send a mod_log message.""" + user_id, creation_timestamp, hmac = TOKEN_RE.search(msg.content).group(0).split('.') + self.mod_log.ignore(Event.message_delete, msg.id) + await msg.delete() + await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) + + message = ( + "Censored a seemingly valid token sent by " + f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " + f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" + ) + log.debug(message) + + # Send pretty mod log embed to mod-alerts + await self.mod_log.send_log_message( + icon_url=Icons.token_removed, + colour=Colour(Colours.soft_red), + title="Token removed!", + text=message, + thumbnail=msg.author.avatar_url_as(static_format="png"), + channel_id=Channels.mod_alerts, + ) + + @classmethod + def is_token_in_message(cls, msg: Message) -> bool: + """Check if `msg` contains a seemly valid token.""" if msg.author.bot: - return + return False maybe_match = TOKEN_RE.search(msg.content) if maybe_match is None: - return + return False try: user_id, creation_timestamp, hmac = maybe_match.group(0).split('.') except ValueError: - return - - if self.is_valid_user_id(user_id) and self.is_valid_timestamp(creation_timestamp): - self.mod_log.ignore(Event.message_delete, msg.id) - await msg.delete() - await msg.channel.send(DELETION_MESSAGE_TEMPLATE.format(mention=msg.author.mention)) - - message = ( - "Censored a seemingly valid token sent by " - f"{msg.author} (`{msg.author.id}`) in {msg.channel.mention}, token was " - f"`{user_id}.{creation_timestamp}.{'x' * len(hmac)}`" - ) - log.debug(message) - - # Send pretty mod log embed to mod-alerts - await self.mod_log.send_log_message( - icon_url=Icons.token_removed, - colour=Colour(Colours.soft_red), - title="Token removed!", - text=message, - thumbnail=msg.author.avatar_url_as(static_format="png"), - channel_id=Channels.mod_alerts, - ) + return False + + if cls.is_valid_user_id(user_id) and cls.is_valid_timestamp(creation_timestamp): + return True @staticmethod def is_valid_user_id(b64_content: str) -> bool: @@ -119,6 +141,5 @@ class TokenRemover(Cog): def setup(bot: Bot) -> None: - """Token Remover cog load.""" + """Load the TokenRemover cog.""" bot.add_cog(TokenRemover(bot)) - log.info("Cog loaded: TokenRemover") diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index 793fe4c1a..47a59db66 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -8,8 +8,9 @@ from typing import Tuple from dateutil import relativedelta from discord import Colour, Embed, Message, Role -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES from bot.decorators import in_channel, with_role from bot.utils.time import humanize_delta @@ -176,6 +177,5 @@ class Utils(Cog): def setup(bot: Bot) -> None: - """Utils cog load.""" + """Load the Utils cog.""" bot.add_cog(Utils(bot)) - log.info("Cog loaded: Utils") diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index b5e8d4357..988e0d49a 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -3,15 +3,17 @@ from datetime import datetime from discord import Colour, Message, NotFound, Object from discord.ext import tasks -from discord.ext.commands import Bot, Cog, Context, command +from discord.ext.commands import Cog, Context, command +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Bot as BotConfig, Channels, Colours, Event, - Filter, Icons, Roles + Filter, Icons, MODERATION_ROLES, Roles ) from bot.decorators import InChannelCheckFailure, in_channel, without_role +from bot.utils.checks import without_role_check log = logging.getLogger(__name__) @@ -37,6 +39,7 @@ PERIODIC_PING = ( f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`." f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel." ) +BOT_MESSAGE_DELETE_DELAY = 10 class Verification(Cog): @@ -54,12 +57,16 @@ class Verification(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Check new message event for messages to the checkpoint channel & process.""" - if message.author.bot: - return # They're a bot, ignore - if message.channel.id != Channels.verification: return # Only listen for #checkpoint messages + if message.author.bot: + # They're a bot, delete their message after the delay. + # But not the periodic ping; we like that one. + if message.content != PERIODIC_PING: + await message.delete(delay=BOT_MESSAGE_DELETE_DELAY) + return + # if a user mentions a role or guild member # alert the mods in mod-alerts channel if message.mentions or message.role_mentions: @@ -189,7 +196,7 @@ class Verification(Cog): @staticmethod def bot_check(ctx: Context) -> bool: """Block any command within the verification channel that is not !accept.""" - if ctx.channel.id == Channels.verification: + if ctx.channel.id == Channels.verification and without_role_check(ctx, *MODERATION_ROLES): return ctx.command.name == "accept" else: return True @@ -224,6 +231,5 @@ class Verification(Cog): def setup(bot: Bot) -> None: - """Verification cog load.""" + """Load the Verification cog.""" bot.add_cog(Verification(bot)) - log.info("Cog loaded: Verification") diff --git a/bot/cogs/watchchannels/__init__.py b/bot/cogs/watchchannels/__init__.py index 86e1050fa..69d118df6 100644 --- a/bot/cogs/watchchannels/__init__.py +++ b/bot/cogs/watchchannels/__init__.py @@ -1,18 +1,9 @@ -import logging - -from discord.ext.commands import Bot - +from bot.bot import Bot from .bigbrother import BigBrother from .talentpool import TalentPool -log = logging.getLogger(__name__) - - def setup(bot: Bot) -> None: - """Monitoring cogs load.""" + """Load the BigBrother and TalentPool cogs.""" bot.add_cog(BigBrother(bot)) - log.info("Cog loaded: BigBrother") - bot.add_cog(TalentPool(bot)) - log.info("Cog loaded: TalentPool") diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py index 49783bb09..306ed4c64 100644 --- a/bot/cogs/watchchannels/bigbrother.py +++ b/bot/cogs/watchchannels/bigbrother.py @@ -3,8 +3,9 @@ from collections import ChainMap from typing import Union from discord import User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group +from bot.bot import Bot from bot.cogs.moderation.utils import post_infraction from bot.constants import Channels, MODERATION_ROLES, Webhooks from bot.decorators import with_role diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index 4ec42dcc1..cc8feeeee 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -4,9 +4,10 @@ from collections import ChainMap from typing import Union from discord import Color, Embed, Member, User -from discord.ext.commands import Bot, Cog, Context, group +from discord.ext.commands import Cog, Context, group from bot.api import ResponseCodeError +from bot.bot import Bot from bot.constants import Channels, Guild, MODERATION_ROLES, STAFF_ROLES, Webhooks from bot.decorators import with_role from bot.pagination import LinePaginator diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index 0bf75a924..bd0622554 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -10,9 +10,10 @@ from typing import Optional import dateutil.parser import discord from discord import Color, Embed, HTTPException, Message, Object, errors -from discord.ext.commands import BadArgument, Bot, Cog, Context +from discord.ext.commands import BadArgument, Cog, Context from bot.api import ResponseCodeError +from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.pagination import LinePaginator diff --git a/bot/cogs/wolfram.py b/bot/cogs/wolfram.py index ab0ed2472..5d6b4630b 100644 --- a/bot/cogs/wolfram.py +++ b/bot/cogs/wolfram.py @@ -7,8 +7,9 @@ import discord from dateutil.relativedelta import relativedelta from discord import Embed from discord.ext import commands -from discord.ext.commands import Bot, BucketType, Cog, Context, check, group +from discord.ext.commands import BucketType, Cog, Context, check, group +from bot.bot import Bot from bot.constants import Colours, STAFF_ROLES, Wolfram from bot.pagination import ImagePaginator from bot.utils.time import humanize_delta @@ -151,7 +152,7 @@ async def get_pod_pages(ctx: Context, bot: Bot, query: str) -> Optional[List[Tup class Wolfram(Cog): """Commands for interacting with the Wolfram|Alpha API.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: Bot): self.bot = bot @group(name="wolfram", aliases=("wolf", "wa"), invoke_without_command=True) @@ -266,7 +267,6 @@ class Wolfram(Cog): await send_embed(ctx, message, color) -def setup(bot: commands.Bot) -> None: - """Wolfram cog load.""" +def setup(bot: Bot) -> None: + """Load the Wolfram cog.""" bot.add_cog(Wolfram(bot)) - log.info("Cog loaded: Wolfram") diff --git a/bot/constants.py b/bot/constants.py index 45f42cf81..8815ab983 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -236,6 +236,13 @@ class Colours(metaclass=YAMLGetter): soft_orange: int +class DuckPond(metaclass=YAMLGetter): + section = "duck_pond" + + threshold: int + custom_emojis: List[int] + + class Emojis(metaclass=YAMLGetter): section = "style" subsection = "emojis" @@ -244,21 +251,26 @@ class Emojis(metaclass=YAMLGetter): defcon_enabled: str # noqa: E704 defcon_updated: str # noqa: E704 - green_chevron: str - red_chevron: str - white_chevron: str - bb_message: str - status_online: str status_offline: str status_idle: str status_dnd: str + failmail: str + bullet: str new: str pencil: str cross_mark: str + ducky_yellow: int + ducky_blurple: int + ducky_regal: int + ducky_camo: int + ducky_ninja: int + ducky_devil: int + ducky_tube: int + upvotes: str comments: str user: str @@ -343,6 +355,7 @@ class Channels(metaclass=YAMLGetter): defcon: int devlog: int devtest: int + esoteric: int help_0: int help_1: int help_2: int @@ -377,6 +390,7 @@ class Webhooks(metaclass=YAMLGetter): talent_pool: int big_brother: int reddit: int + duck_pond: int class Roles(metaclass=YAMLGetter): @@ -453,6 +467,8 @@ class Reddit(metaclass=YAMLGetter): section = "reddit" subreddits: list + client_id: str + secret: str class Wolfram(metaclass=YAMLGetter): @@ -508,6 +524,30 @@ class RedirectOutput(metaclass=YAMLGetter): delete_delay: int +class Event(Enum): + """ + Event names. This does not include every event (for example, raw + events aren't here), but only events used in ModLog for now. + """ + + guild_channel_create = "guild_channel_create" + guild_channel_delete = "guild_channel_delete" + guild_channel_update = "guild_channel_update" + guild_role_create = "guild_role_create" + guild_role_delete = "guild_role_delete" + guild_role_update = "guild_role_update" + guild_update = "guild_update" + + member_join = "member_join" + member_remove = "member_remove" + member_ban = "member_ban" + member_unban = "member_unban" + member_update = "member_update" + + message_delete = "message_delete" + message_edit = "message_edit" + + # Debug mode DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False @@ -579,27 +619,3 @@ ERROR_REPLIES = [ "Noooooo!!", "I can't believe you've done this", ] - - -class Event(Enum): - """ - Event names. This does not include every event (for example, raw - events aren't here), but only events used in ModLog for now. - """ - - guild_channel_create = "guild_channel_create" - guild_channel_delete = "guild_channel_delete" - guild_channel_update = "guild_channel_update" - guild_role_create = "guild_role_create" - guild_role_delete = "guild_role_delete" - guild_role_update = "guild_role_update" - guild_update = "guild_update" - - member_join = "member_join" - member_remove = "member_remove" - member_ban = "member_ban" - member_unban = "member_unban" - member_update = "member_update" - - message_delete = "message_delete" - message_edit = "message_edit" diff --git a/bot/converters.py b/bot/converters.py index cf0496541..8d2ab7eb8 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -1,8 +1,8 @@ import logging import re +import typing as t from datetime import datetime from ssl import CertificateError -from typing import Union import dateutil.parser import dateutil.tz @@ -15,6 +15,25 @@ from discord.ext.commands import BadArgument, Context, Converter log = logging.getLogger(__name__) +def allowed_strings(*values, preserve_case: bool = False) -> t.Callable[[str], str]: + """ + Return a converter which only allows arguments equal to one of the given values. + + Unless preserve_case is True, the argument is converted to lowercase. All values are then + expected to have already been given in lowercase too. + """ + def converter(arg: str) -> str: + if not preserve_case: + arg = arg.lower() + + if arg not in values: + raise BadArgument(f"Only the following values are allowed:\n```{', '.join(values)}```") + else: + return arg + + return converter + + class ValidPythonIdentifier(Converter): """ A converter that checks whether the given string is a valid Python identifier. @@ -70,7 +89,7 @@ class InfractionSearchQuery(Converter): """A converter that checks if the argument is a Discord user, and if not, falls back to a string.""" @staticmethod - async def convert(ctx: Context, arg: str) -> Union[discord.Member, str]: + async def convert(ctx: Context, arg: str) -> t.Union[discord.Member, str]: """Check if the argument is a Discord user, and if not, falls back to a string.""" try: maybe_snowflake = arg.strip("<@!>") diff --git a/bot/decorators.py b/bot/decorators.py index 935df4af0..2d18eaa6a 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -27,11 +27,23 @@ class InChannelCheckFailure(CheckFailure): super().__init__(f"Sorry, but you may only use this command within {channels_str}.") -def in_channel(*channels: int, bypass_roles: Container[int] = None) -> Callable: - """Checks that the message is in a whitelisted channel or optionally has a bypass role.""" +def in_channel( + *channels: int, + hidden_channels: Container[int] = None, + bypass_roles: Container[int] = None +) -> Callable: + """ + Checks that the message is in a whitelisted channel or optionally has a bypass role. + + Hidden channels are channels which will not be displayed in the InChannelCheckFailure error + message. + """ + hidden_channels = hidden_channels or [] + bypass_roles = bypass_roles or [] + def predicate(ctx: Context) -> bool: """In-channel checker predicate.""" - if ctx.channel.id in channels: + if ctx.channel.id in channels or ctx.channel.id in hidden_channels: log.debug(f"{ctx.author} tried to call the '{ctx.command.name}' command. " f"The command was used in a whitelisted channel.") return True diff --git a/bot/interpreter.py b/bot/interpreter.py index a42b45a2d..8b7268746 100644 --- a/bot/interpreter.py +++ b/bot/interpreter.py @@ -2,7 +2,9 @@ from code import InteractiveInterpreter from io import StringIO from typing import Any -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot CODE_TEMPLATE = """ async def _func(): @@ -20,8 +22,8 @@ class Interpreter(InteractiveInterpreter): write_callable = None def __init__(self, bot: Bot): - _locals = {"bot": bot} - super().__init__(_locals) + locals_ = {"bot": bot} + super().__init__(locals_) async def run(self, code: str, ctx: Context, io: StringIO, *args, **kwargs) -> Any: """Execute the provided source code as the bot & return the output.""" diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index 08abd91d7..ee6c0a8e6 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -36,11 +36,15 @@ class Scheduler(metaclass=CogABCMeta): `task_data` is passed to `Scheduler._scheduled_expiration` """ if task_id in self.scheduled_tasks: + log.debug( + f"{self.cog_name}: did not schedule task #{task_id}; task was already scheduled." + ) return task: asyncio.Task = create_task(loop, self._scheduled_task(task_data)) self.scheduled_tasks[task_id] = task + log.debug(f"{self.cog_name}: scheduled task #{task_id}.") def cancel_task(self, task_id: str) -> None: """Un-schedules a task.""" @@ -51,7 +55,7 @@ class Scheduler(metaclass=CogABCMeta): return task.cancel() - log.debug(f"{self.cog_name}: Unscheduled {task_id}.") + log.debug(f"{self.cog_name}: unscheduled task #{task_id}.") del self.scheduled_tasks[task_id] diff --git a/bot/utils/time.py b/bot/utils/time.py index 2aea2c099..7416f36e0 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -111,3 +111,55 @@ async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] def format_infraction(timestamp: str) -> str: """Format an infraction timestamp to a more readable ISO 8601 format.""" return dateutil.parser.isoparse(timestamp).strftime(INFRACTION_FORMAT) + + +def format_infraction_with_duration( + expiry: Optional[str], + date_from: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: + """ + Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. + + Returns a human-readable version of the duration between datetime.utcnow() and an expiry. + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. + """ + if not expiry: + return None + + date_from = date_from or datetime.datetime.utcnow() + date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + + expiry_formatted = format_infraction(expiry) + + duration = humanize_delta(relativedelta(date_to, date_from), max_units=max_units) + duration_formatted = f" ({duration})" if duration else '' + + return f"{expiry_formatted}{duration_formatted}" + + +def until_expiration( + expiry: Optional[str], + now: Optional[datetime.datetime] = None, + max_units: int = 2 +) -> Optional[str]: + """ + Get the remaining time until infraction's expiration, in a human-readable version of the relativedelta. + + Returns a human-readable version of the remaining duration between datetime.utcnow() and an expiry. + Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. + `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). + By default, max_units is 2. + """ + if not expiry: + return None + + now = now or datetime.datetime.utcnow() + since = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + + if since < now: + return None + + return humanize_delta(relativedelta(since, now), max_units=max_units) diff --git a/config-default.yml b/config-default.yml index ee9f8a06b..6ae07da93 100644 --- a/config-default.yml +++ b/config-default.yml @@ -22,21 +22,26 @@ style: defcon_enabled: "<:defconenabled:470326274213150730>" defcon_updated: "<:defconsettingsupdated:470326274082996224>" - green_chevron: "<:greenchevron:418104310329769993>" - red_chevron: "<:redchevron:418112778184818698>" - white_chevron: "<:whitechevron:418110396973711363>" - bb_message: "<:bbmessage:476273120999636992>" - status_online: "<:status_online:470326272351010816>" status_idle: "<:status_idle:470326266625785866>" status_dnd: "<:status_dnd:470326272082313216>" status_offline: "<:status_offline:470326266537705472>" + failmail: "<:failmail:633660039931887616>" + bullet: "\u2022" pencil: "\u270F" new: "\U0001F195" cross_mark: "\u274C" + ducky_yellow: &DUCKY_YELLOW 574951975574175744 + ducky_blurple: &DUCKY_BLURPLE 574951975310065675 + ducky_regal: &DUCKY_REGAL 637883439185395712 + ducky_camo: &DUCKY_CAMO 637914731566596096 + ducky_ninja: &DUCKY_NINJA 637923502535606293 + ducky_devil: &DUCKY_DEVIL 637925314982576139 + ducky_tube: &DUCKY_TUBE 637881368008851456 + upvotes: "<:upvotes:638729835245731840>" comments: "<:comments:638729835073765387>" user: "<:user:638729835442602003>" @@ -105,6 +110,7 @@ guild: defcon: &DEFCON 464469101889454091 devlog: &DEVLOG 622895325144940554 devtest: &DEVTEST 414574275865870337 + esoteric: 470884583684964352 help_0: 303906576991780866 help_1: 303906556754395136 help_2: 303906514266226689 @@ -155,6 +161,7 @@ guild: talent_pool: 569145364800602132 big_brother: 569133704568373283 reddit: 635408384794951680 + duck_pond: 637821475327311927 filter: @@ -360,6 +367,8 @@ anti_malware: reddit: subreddits: - 'r/Python' + client_id: !ENV "REDDIT_CLIENT_ID" + secret: !ENV "REDDIT_SECRET" wolfram: @@ -389,5 +398,9 @@ redirect_output: delete_invocation: true delete_delay: 15 +duck_pond: + threshold: 5 + custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE] + config: required_keys: ['bot.token'] diff --git a/docker-compose.yml b/docker-compose.yml index f79fdba58..7281c7953 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,3 +42,5 @@ services: environment: BOT_TOKEN: ${BOT_TOKEN} BOT_API_KEY: badbot13m0n8f570f942013fc818f234916ca531 + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_SECRET: ${REDDIT_SECRET} diff --git a/tests/README.md b/tests/README.md index 6ab9bc93e..d052de2f6 100644 --- a/tests/README.md +++ b/tests/README.md @@ -15,6 +15,7 @@ We are using the following modules and packages for our unit tests: To ensure the results you obtain on your personal machine are comparable to those generated in the Azure pipeline, please make sure to run your tests with the virtual environment defined by our [Pipfile](/Pipfile). To run your tests with `pipenv`, we've provided two "scripts" shortcuts: - `pipenv run test` will run `unittest` with `coverage.py` +- `pipenv run test path/to/test.py` will run a specific test. - `pipenv run report` will generate a coverage report of the tests you've run with `pipenv run test`. If you append the `-m` flag to this command, the report will include the lines and branches not covered by tests in addition to the test coverage report. If you want a coverage report, make sure to run the tests with `pipenv run test` *first*. diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py new file mode 100644 index 000000000..d07b2bce1 --- /dev/null +++ b/tests/bot/cogs/test_duck_pond.py @@ -0,0 +1,584 @@ +import asyncio +import logging +import typing +import unittest +from unittest.mock import MagicMock, patch + +import discord + +from bot import constants +from bot.cogs import duck_pond +from tests import base +from tests import helpers + +MODULE_PATH = "bot.cogs.duck_pond" + + +class DuckPondTests(base.LoggingTestCase): + """Tests for DuckPond functionality.""" + + @classmethod + def setUpClass(cls): + """Sets up the objects that only have to be initialized once.""" + cls.nonstaff_member = helpers.MockMember(name="Non-staffer") + + cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) + cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) + + cls.checkmark_emoji = "\N{White Heavy Check Mark}" + cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" + cls.unicode_duck_emoji = "\N{Duck}" + cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) + cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) + + def setUp(self): + """Sets up the objects that need to be refreshed before each test.""" + self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) + self.cog = duck_pond.DuckPond(bot=self.bot) + + def test_duck_pond_correctly_initializes(self): + """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" + bot = helpers.MockBot() + cog = MagicMock() + + duck_pond.DuckPond.__init__(cog, bot) + + self.assertEqual(cog.bot, bot) + self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) + bot.loop.create_loop.called_once_with(cog.fetch_webhook()) + + def test_fetch_webhook_succeeds_without_connectivity_issues(self): + """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" + self.bot.fetch_webhook.return_value = "dummy webhook" + self.cog.webhook_id = 1 + + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_ready.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + self.assertEqual(self.cog.webhook, "dummy webhook") + + def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): + """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" + self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") + self.cog.webhook_id = 1 + + log = logging.getLogger('bot.cogs.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.fetch_webhook()) + + self.bot.wait_until_ready.assert_called_once() + self.bot.fetch_webhook.assert_called_once_with(1) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def test_is_staff_returns_correct_values_based_on_instance_passed(self): + """The `is_staff` method should return correct values based on the instance passed.""" + test_cases = ( + (helpers.MockUser(name="User instance"), False), + (helpers.MockMember(name="Member instance without staff role"), False), + (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) + ) + + for user, expected_return in test_cases: + actual_return = self.cog.is_staff(user) + with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @helpers.async_test + async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): + """The `has_green_checkmark` method should only return `True` if one is present.""" + test_cases = ( + ( + "No reactions", helpers.MockMessage(), False + ), + ( + "No green check mark reactions", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) + ]), + False + ), + ( + "Green check mark reaction, but not from our bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) + ]), + False + ), + ( + "Green check mark reaction, with one from the bot", + helpers.MockMessage(reactions=[ + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), + helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) + ]), + True + ) + ) + + for description, message, expected_return in test_cases: + actual_return = await self.cog.has_green_checkmark(message) + with self.subTest( + test_case=description, + expected_return=expected_return, + actual_return=actual_return + ): + self.assertEqual(expected_return, actual_return) + + def test_send_webhook_correctly_passes_on_arguments(self): + """The `send_webhook` method should pass the arguments to the webhook correctly.""" + self.cog.webhook = helpers.MockAsyncWebhook() + + content = "fake content" + username = "fake username" + avatar_url = "fake avatar_url" + embed = "fake embed" + + asyncio.run(self.cog.send_webhook(content, username, avatar_url, embed)) + + self.cog.webhook.send.assert_called_once_with( + content=content, + username=username, + avatar_url=avatar_url, + embed=embed + ) + + def test_send_webhook_logs_when_sending_message_fails(self): + """The `send_webhook` method should catch a `discord.HTTPException` and log accordingly.""" + self.cog.webhook = helpers.MockAsyncWebhook() + self.cog.webhook.send.side_effect = discord.HTTPException(response=MagicMock(), message="Something failed.") + + log = logging.getLogger('bot.cogs.duck_pond') + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + asyncio.run(self.cog.send_webhook()) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _get_reaction( + self, + emoji: typing.Union[str, helpers.MockEmoji], + staff: int = 0, + nonstaff: int = 0 + ) -> helpers.MockReaction: + staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] + nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] + return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) + + @helpers.async_test + async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): + """The `count_ducks` method should return the number of unique staffers who gave a duck.""" + test_cases = ( + # Simple test cases + # A message without reactions should return 0 + ( + "No reactions", + helpers.MockMessage(), + 0 + ), + # A message with a non-duck reaction from a non-staffer should return 0 + ( + "Non-duck reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), + 0 + ), + # A message with a non-duck reaction from a staffer should return 0 + ( + "Non-duck reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), + 0 + ), + # A message with a non-duck reaction from a non-staffer and staffer should return 0 + ( + "Non-duck reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a non-staffer should return 0 + ( + "Unicode Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), + 0 + ), + # A message with a unicode duck reaction from a staffer should return 1 + ( + "Unicode Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), + 1 + ), + # A message with a unicode duck reaction from a non-staffer and staffer should return 1 + ( + "Unicode Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer should return 0 + ( + "Duckpond Duck Reaction from non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), + 0 + ), + # A message with a duckpond duck reaction from a staffer should return 1 + ( + "Duckpond Duck Reaction from staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), + 1 + ), + # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 + ( + "Duckpond Duck Reaction from staffer + non-staffer", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), + 1 + ), + + # Complex test cases + # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), + 3 + ), + # A staffer with multiple duck reactions only counts once + ( + "Two different duck reactions from the same staffer", + helpers.MockMessage( + reactions=[ + helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), + helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), + ] + ), + 1 + ), + # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) + ( + "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", + helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), + 0 + ), + # We correctly sum when multiple reactions are provided. + ( + "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", + helpers.MockMessage( + reactions=[ + self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), + self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), + ] + ), + 3 + 4 + ), + ) + + for description, message, expected_count in test_cases: + actual_count = await self.cog.count_ducks(message) + with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): + self.assertEqual(expected_count, actual_count) + + @helpers.async_test + async def test_relay_message_correctly_relays_content_and_attachments(self): + """The `relay_message` method should correctly relay message content and attachments.""" + send_webhook_path = f"{MODULE_PATH}.DuckPond.send_webhook" + send_attachments_path = f"{MODULE_PATH}.send_attachments" + + self.cog.webhook = helpers.MockAsyncWebhook() + + test_values = ( + (helpers.MockMessage(clean_content="", attachments=[]), False, False), + (helpers.MockMessage(clean_content="message", attachments=[]), True, False), + (helpers.MockMessage(clean_content="", attachments=["attachment"]), False, True), + (helpers.MockMessage(clean_content="message", attachments=["attachment"]), True, True), + ) + + for message, expect_webhook_call, expect_attachment_call in test_values: + with patch(send_webhook_path, new_callable=helpers.AsyncMock) as send_webhook: + with patch(send_attachments_path, new_callable=helpers.AsyncMock) as send_attachments: + with self.subTest(clean_content=message.clean_content, attachments=message.attachments): + await self.cog.relay_message(message) + + self.assertEqual(expect_webhook_call, send_webhook.called) + self.assertEqual(expect_attachment_call, send_attachments.called) + + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + + @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.cogs.duck_pond") + + for side_effect in side_effects: + send_attachments.side_effect = side_effect + with patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) as send_webhook: + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertNotLogs(logger=log, level=logging.ERROR): + await self.cog.relay_message(message) + + self.assertEqual(send_webhook.call_count, 2) + + @patch(f"{MODULE_PATH}.DuckPond.send_webhook", new_callable=helpers.AsyncMock) + @patch(f"{MODULE_PATH}.send_attachments", new_callable=helpers.AsyncMock) + @helpers.async_test + async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): + """The `relay_message` method should handle irretrievable attachments.""" + message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) + + self.cog.webhook = helpers.MockAsyncWebhook() + log = logging.getLogger("bot.cogs.duck_pond") + + side_effect = discord.HTTPException(MagicMock(), "") + send_attachments.side_effect = side_effect + with self.subTest(side_effect=type(side_effect).__name__): + with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: + await self.cog.relay_message(message) + + send_webhook.assert_called_once_with( + content=message.clean_content, + username=message.author.display_name, + avatar_url=message.author.avatar_url + ) + + self.assertEqual(len(log_watcher.records), 1) + + record = log_watcher.records[0] + self.assertEqual(record.levelno, logging.ERROR) + + def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): + """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" + payload = MagicMock(name=label) + payload.emoji.is_custom_emoji.return_value = is_custom_emoji + payload.emoji.id = id_ + payload.emoji.name = emoji_name + return payload + + @helpers.async_test + async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): + """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" + test_values = ( + # Custom Emojis + ( + self._mock_payload( + label="Custom Duckpond Emoji", + is_custom_emoji=True, + id_=constants.DuckPond.custom_emojis[0], + emoji_name="" + ), + True + ), + ( + self._mock_payload( + label="Custom Non-Duckpond Emoji", + is_custom_emoji=True, + id_=123, + emoji_name="" + ), + False + ), + # Unicode Emojis + ( + self._mock_payload( + label="Unicode Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.unicode_duck_emoji + ), + True + ), + ( + self._mock_payload( + label="Unicode Non-Duck Emoji", + is_custom_emoji=False, + id_=1, + emoji_name=self.thumbs_up_emoji + ), + False + ), + ) + + for payload, expected_return in test_values: + actual_return = self.cog._payload_has_duckpond_emoji(payload) + with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): + self.assertEqual(expected_return, actual_return) + + @patch(f"{MODULE_PATH}.discord.utils.get") + @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) + def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): + """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) + + # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check + utils_get.assert_not_called() + + def _raw_reaction_mocks(self, channel_id, message_id, user_id): + """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" + channel = helpers.MockTextChannel(id=channel_id) + self.bot.get_all_channels.return_value = (channel,) + + message = helpers.MockMessage(id=message_id) + + channel.fetch_message.return_value = message + + member = helpers.MockMember(id=user_id, roles=[self.staff_role]) + message.guild.members = (member,) + + payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) + + return channel, message, member, payload + + @helpers.async_test + async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): + """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" + channel_id = 1234 + message_id = 2345 + user_id = 3456 + + channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + test_cases = ( + ("non-staff member", helpers.MockMember(id=user_id)), + ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), + ) + + payload.emoji = self.duck_pond_emoji + + for description, member in test_cases: + message.guild.members = (member, ) + with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: + checkmark.side_effect = AssertionError( + "Expected method to return before calling `self.has_green_checkmark`." + ) + self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) + + # Check that we did make it past the payload checks + channel.fetch_message.assert_called_once() + channel.fetch_message.reset_mock() + + @patch(f"{MODULE_PATH}.DuckPond.is_staff") + @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) + def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): + """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" + channel_id = 31415926535 + message_id = 27182818284 + user_id = 16180339887 + + channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) + + payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) + payload.emoji.is_custom_emoji.return_value = False + + message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] + + is_staff.return_value = True + count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) + + # Assert that we've made it past `self.is_staff` + is_staff.assert_called_once() + + @helpers.async_test + async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): + """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + + channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) + + payload.emoji = self.duck_pond_emoji + + for duck_count, should_relay in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=helpers.AsyncMock) as relay_message: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_relay=should_relay): + await self.cog.on_raw_reaction_add(payload) + + # Confirm that we've made it past counting + count_ducks.assert_called_once() + + # Did we relay a message? + has_relayed = relay_message.called + self.assertEqual(has_relayed, should_relay) + + if should_relay: + relay_message.assert_called_once_with(message) + + @helpers.async_test + async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): + """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" + checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) + + message = helpers.MockMessage(id=1234) + + channel = helpers.MockTextChannel(id=98765) + channel.fetch_message.return_value = message + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) + + test_cases = ( + (constants.DuckPond.threshold - 1, False), + (constants.DuckPond.threshold, True), + (constants.DuckPond.threshold + 1, True), + ) + for duck_count, should_re_add_checkmark in test_cases: + with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=helpers.AsyncMock) as count_ducks: + count_ducks.return_value = duck_count + with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): + await self.cog.on_raw_reaction_remove(payload) + + # Check if we fetched the message + channel.fetch_message.assert_called_once_with(message.id) + + # Check if we actually counted the number of ducks + count_ducks.assert_called_once_with(message) + + has_re_added_checkmark = message.add_reaction.called + self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) + + if should_re_add_checkmark: + message.add_reaction.assert_called_once_with(self.checkmark_emoji) + message.add_reaction.reset_mock() + + # reset mocks + channel.fetch_message.reset_mock() + message.reset_mock() + + def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): + """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" + channel = helpers.MockTextChannel(id=98765) + + channel.fetch_message.side_effect = AssertionError( + "Expected method to return before calling `channel.fetch_message`" + ) + + self.bot.get_all_channels.return_value = (channel, ) + + payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) + + self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) + + channel.fetch_message.assert_not_called() + + +class DuckPondSetupTests(unittest.TestCase): + """Tests setup of the `DuckPond` cog.""" + + def test_setup(self): + """Setup of the extension should call add_cog.""" + bot = helpers.MockBot() + duck_pond.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py index efa7a50b1..9d1a62f7e 100644 --- a/tests/bot/cogs/test_security.py +++ b/tests/bot/cogs/test_security.py @@ -1,4 +1,3 @@ -import logging import unittest from unittest.mock import MagicMock @@ -49,11 +48,7 @@ class SecurityCogLoadTests(unittest.TestCase): """Tests loading the `Security` cog.""" def test_security_cog_load(self): - """Cog loading logs a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MagicMock() - with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: - security.setup(bot) - bot.add_cog.assert_called_once() - - [line] = cm.output - self.assertIn("Cog loaded: Security", line) + security.setup(bot) + bot.add_cog.assert_called_once() diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py index 3276cf5a5..a54b839d7 100644 --- a/tests/bot/cogs/test_token_remover.py +++ b/tests/bot/cogs/test_token_remover.py @@ -125,11 +125,7 @@ class TokenRemoverSetupTests(unittest.TestCase): """Tests setup of the `TokenRemover` cog.""" def test_setup(self): - """Setup of the cog should log a message at `INFO` level.""" + """Setup of the extension should call add_cog.""" bot = MockBot() - with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: - setup_cog(bot) - - [line] = cm.output + setup_cog(bot) bot.add_cog.assert_called_once() - self.assertIn("Cog loaded: TokenRemover", line) diff --git a/tests/bot/utils/test_time.py b/tests/bot/utils/test_time.py new file mode 100644 index 000000000..69f35f2f5 --- /dev/null +++ b/tests/bot/utils/test_time.py @@ -0,0 +1,162 @@ +import asyncio +import unittest +from datetime import datetime, timezone +from unittest.mock import patch + +from dateutil.relativedelta import relativedelta + +from bot.utils import time +from tests.helpers import AsyncMock + + +class TimeTests(unittest.TestCase): + """Test helper functions in bot.utils.time.""" + + def test_humanize_delta_handle_unknown_units(self): + """humanize_delta should be able to handle unknown units, and will not abort.""" + # Does not abort for unknown units, as the unit name is checked + # against the attribute of the relativedelta instance. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'elephants', 2), '2 days and 2 hours') + + def test_humanize_delta_handle_high_units(self): + """humanize_delta should be able to handle very high units.""" + # Very high maximum units, but it only ever iterates over + # each value the relativedelta might have. + self.assertEqual(time.humanize_delta(relativedelta(days=2, hours=2), 'hours', 20), '2 days and 2 hours') + + def test_humanize_delta_should_normal_usage(self): + """Testing humanize delta.""" + test_cases = ( + (relativedelta(days=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'), + (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'), + (relativedelta(days=2, hours=2), 'days', 2, '2 days'), + ) + + for delta, precision, max_units, expected in test_cases: + with self.subTest(delta=delta, precision=precision, max_units=max_units, expected=expected): + self.assertEqual(time.humanize_delta(delta, precision, max_units), expected) + + def test_humanize_delta_raises_for_invalid_max_units(self): + """humanize_delta should raises ValueError('max_units must be positive') for invalid max_units.""" + test_cases = (-1, 0) + + for max_units in test_cases: + with self.subTest(max_units=max_units), self.assertRaises(ValueError) as error: + time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units) + self.assertEqual(str(error), 'max_units must be positive') + + def test_parse_rfc1123(self): + """Testing parse_rfc1123.""" + self.assertEqual( + time.parse_rfc1123('Sun, 15 Sep 2019 12:00:00 GMT'), + datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc) + ) + + def test_format_infraction(self): + """Testing format_infraction.""" + self.assertEqual(time.format_infraction('2019-12-12T00:01:00Z'), '2019-12-12 00:01') + + @patch('asyncio.sleep', new_callable=AsyncMock) + def test_wait_until(self, mock): + """Testing wait_until.""" + start = datetime(2019, 1, 1, 0, 0) + then = datetime(2019, 1, 1, 0, 10) + + # No return value + self.assertIs(asyncio.run(time.wait_until(then, start)), None) + + mock.assert_called_once_with(10 * 60) + + def test_format_infraction_with_duration_none_expiry(self): + """format_infraction_with_duration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that date_from and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_custom_units(self): + """format_infraction_with_duration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, + '2019-12-12 00:01 (11 hours, 55 minutes and 55 seconds)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, + '2019-11-23 20:09 (6 months, 28 days, 23 hours and 54 minutes)') + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_format_infraction_with_duration_normal_usage(self): + """format_infraction_with_duration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '2019-12-12 00:01 (12 hours and 55 seconds)'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '2019-12-12 00:01 (12 hours)'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '2019-12-12 00:00 (1 minute)'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '2019-11-23 20:09 (7 days and 23 hours)'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '2019-11-23 20:09 (6 months and 28 days)'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '2019-11-23 20:58 (5 minutes)'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '2019-11-24 00:00 (1 minute)'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2019-11-23 23:59 (2 years and 4 months)'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, + '2019-11-23 23:59 (9 minutes and 55 seconds)'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, date_from, max_units, expected in test_cases: + with self.subTest(expiry=expiry, date_from=date_from, max_units=max_units, expected=expected): + self.assertEqual(time.format_infraction_with_duration(expiry, date_from, max_units), expected) + + def test_until_expiration_with_duration_none_expiry(self): + """until_expiration should work for None expiry.""" + test_cases = ( + (None, None, None, None), + + # To make sure that now and max_units are not touched + (None, 'Why hello there!', None, None), + (None, None, float('inf'), None), + (None, 'Why hello there!', float('inf'), None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_with_duration_custom_units(self): + """until_expiration should work for custom max_units.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 5, 5), 6, '11 hours, 55 minutes and 55 seconds'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 20, '6 months, 28 days, 23 hours and 54 minutes') + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) + + def test_until_expiration_normal_usage(self): + """until_expiration should work for normal usage, across various durations.""" + test_cases = ( + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 2, '12 hours and 55 seconds'), + ('2019-12-12T00:01:00Z', datetime(2019, 12, 11, 12, 0, 5), 1, '12 hours'), + ('2019-12-12T00:00:00Z', datetime(2019, 12, 11, 23, 59), 2, '1 minute'), + ('2019-11-23T20:09:00Z', datetime(2019, 11, 15, 20, 15), 2, '7 days and 23 hours'), + ('2019-11-23T20:09:00Z', datetime(2019, 4, 25, 20, 15), 2, '6 months and 28 days'), + ('2019-11-23T20:58:00Z', datetime(2019, 11, 23, 20, 53), 2, '5 minutes'), + ('2019-11-24T00:00:00Z', datetime(2019, 11, 23, 23, 59, 0), 2, '1 minute'), + ('2019-11-23T23:59:00Z', datetime(2017, 7, 21, 23, 0), 2, '2 years and 4 months'), + ('2019-11-23T23:59:00Z', datetime(2019, 11, 23, 23, 49, 5), 2, '9 minutes and 55 seconds'), + (None, datetime(2019, 11, 23, 23, 49, 5), 2, None), + ) + + for expiry, now, max_units, expected in test_cases: + with self.subTest(expiry=expiry, now=now, max_units=max_units, expected=expected): + self.assertEqual(time.until_expiration(expiry, now, max_units), expected) diff --git a/tests/helpers.py b/tests/helpers.py index 8a14aeef4..5df796c23 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -10,7 +10,9 @@ import unittest.mock from typing import Any, Iterable, Optional import discord -from discord.ext.commands import Bot, Context +from discord.ext.commands import Context + +from bot.bot import Bot for logger in logging.Logger.manager.loggerDict.values(): @@ -120,8 +122,80 @@ class AsyncMock(CustomMockMixin, unittest.mock.MagicMock): Python 3.8 will introduce an AsyncMock class in the standard library that will have some more features; this stand-in only overwrites the `__call__` method to an async version. """ + async def __call__(self, *args, **kwargs): - return super(AsyncMock, self).__call__(*args, **kwargs) + return super().__call__(*args, **kwargs) + + +class AsyncIteratorMock: + """ + A class to mock asynchronous iterators. + + This allows async for, which is used in certain Discord.py objects. For example, + an async iterator is returned by the Reaction.users() method. + """ + + def __init__(self, iterable: Iterable = None): + if iterable is None: + iterable = [] + + self.iter = iter(iterable) + self.iterable = iterable + + self.call_count = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + def __call__(self): + """ + Keeps track of the number of times an instance has been called. + + This is useful, since it typically shows that the iterator has actually been used somewhere after we have + instantiated the mock for an attribute that normally returns an iterator when called. + """ + self.call_count += 1 + return self + + @property + def return_value(self): + """Makes `self.iterable` accessible as self.return_value.""" + return self.iterable + + @return_value.setter + def return_value(self, iterable): + """Stores the `return_value` as `self.iterable` and its iterator as `self.iter`.""" + self.iter = iter(iterable) + self.iterable = iterable + + def assert_called(self): + """Asserts if the AsyncIteratorMock instance has been called at least once.""" + if self.call_count == 0: + raise AssertionError("Expected AsyncIteratorMock to have been called.") + + def assert_called_once(self): + """Asserts if the AsyncIteratorMock instance has been called exactly once.""" + if self.call_count != 1: + raise AssertionError( + f"Expected AsyncIteratorMock to have been called once. Called {self.call_count} times." + ) + + def assert_not_called(self): + """Asserts if the AsyncIteratorMock instance has not been called.""" + if self.call_count != 0: + raise AssertionError( + f"Expected AsyncIteratorMock to not have been called once. Called {self.call_count} times." + ) + + def reset_mock(self): + """Resets the call count, but not the return value or iterator.""" + self.call_count = 0 # Create a guild instance to get a realistic Mock of `discord.Guild` @@ -220,7 +294,7 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin information, see the `MockGuild` docstring. """ def __init__(self, roles: Optional[Iterable[MockRole]] = None, **kwargs) -> None: - default_kwargs = {'name': 'member', 'id': next(self.discord_id)} + default_kwargs = {'name': 'member', 'id': next(self.discord_id), 'bot': False} super().__init__(spec_set=member_instance, **collections.ChainMap(kwargs, default_kwargs)) self.roles = [MockRole(name="@everyone", position=1, id=0)] @@ -231,6 +305,25 @@ class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin self.mention = f"@{self.name}" +# Create a User instance to get a realistic Mock of `discord.User` +user_instance = discord.User(data=unittest.mock.MagicMock(), state=unittest.mock.MagicMock()) + + +class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): + """ + A Mock subclass to mock User objects. + + Instances of this class will follow the specifications of `discord.User` instances. For more + information, see the `MockGuild` docstring. + """ + def __init__(self, **kwargs) -> None: + default_kwargs = {'name': 'user', 'id': next(self.discord_id), 'bot': False} + super().__init__(spec_set=user_instance, **collections.ChainMap(kwargs, default_kwargs)) + + if 'mention' not in kwargs: + self.mention = f"@{self.name}" + + # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) bot_instance.http_session = None @@ -244,6 +337,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=bot_instance, **kwargs) @@ -281,6 +375,7 @@ class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin): Instances of this class will follow the specifications of `discord.TextChannel` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: default_kwargs = {'id': next(self.discord_id), 'name': 'channel', 'guild': MockGuild()} super().__init__(spec_set=channel_instance, **collections.ChainMap(kwargs, default_kwargs)) @@ -322,6 +417,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Context` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=context_instance, **kwargs) self.bot = kwargs.get('bot', MockBot()) @@ -330,6 +426,20 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock): self.channel = kwargs.get('channel', MockTextChannel()) +attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) + + +class MockAttachment(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Attachment objects. + + Instances of this class will follow the specifications of `discord.Attachment` instances. For + more information, see the `MockGuild` docstring. + """ + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=attachment_instance, **kwargs) + + class MockMessage(CustomMockMixin, unittest.mock.MagicMock): """ A MagicMock subclass to mock Message objects. @@ -337,8 +447,10 @@ class MockMessage(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Message` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: - super().__init__(spec_set=message_instance, **kwargs) + default_kwargs = {'attachments': []} + super().__init__(spec_set=message_instance, **collections.ChainMap(kwargs, default_kwargs)) self.author = kwargs.get('author', MockMember()) self.channel = kwargs.get('channel', MockTextChannel()) @@ -354,6 +466,7 @@ class MockEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Emoji` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=emoji_instance, **kwargs) self.guild = kwargs.get('guild', MockGuild()) @@ -369,6 +482,7 @@ class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=partial_emoji_instance, **kwargs) @@ -383,7 +497,31 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Reaction` instances. For more information, see the `MockGuild` docstring. """ + def __init__(self, **kwargs) -> None: super().__init__(spec_set=reaction_instance, **kwargs) self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) + self.users = AsyncIteratorMock(kwargs.get('users', [])) + + +webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) + + +class MockAsyncWebhook(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Webhook objects using an AsyncWebhookAdapter. + + Instances of this class will follow the specifications of `discord.Webhook` instances. For + more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=webhook_instance, **kwargs) + + # Because Webhooks can also use a synchronous "WebhookAdapter", the methods are not defined + # as coroutines. That's why we need to set the methods manually. + self.send = AsyncMock() + self.edit = AsyncMock() + self.delete = AsyncMock() + self.execute = AsyncMock() |