diff options
46 files changed, 2887 insertions, 881 deletions
| diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY index eacd9b952..ab715630d 100644 --- a/LICENSE-THIRD-PARTY +++ b/LICENSE-THIRD-PARTY @@ -35,6 +35,36 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.  --------------------------------------------------------------------------------------------------- +                                       BSD 2-Clause License +Applies to: +    - Copyright (c) 2007-2020 by the Sphinx team (see AUTHORS file). All rights reserved. +        - bot/cogs/doc/inventory_parser.py: _load_v1, _load_v2 and ZlibStreamReader.__aiter__. +--------------------------------------------------------------------------------------------------- + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright +  notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright +  notice, this list of conditions and the following disclaimer in the +  documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------------------------------------                             PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2  Applies to:      - Copyright © 2001-2020 Python Software Foundation. All rights reserved. @@ -20,15 +20,13 @@ emoji = "~=0.6"  feedparser = "~=5.2"  fuzzywuzzy = "~=0.17"  lxml = "~=4.4" -markdownify = "==0.5.3" +markdownify = "==0.6.1"  more_itertools = "~=8.2"  python-dateutil = "~=2.8"  python-frontmatter = "~=1.0.0"  pyyaml = "~=5.1"  regex = "==2021.4.4" -requests = "~=2.22"  sentry-sdk = "~=0.19" -sphinx = "~=2.2"  statsd = "~=3.3"  [dev-packages] diff --git a/Pipfile.lock b/Pipfile.lock index d6792ac35..1e1a8167b 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@  {      "_meta": {          "hash": { -            "sha256": "fc3421fc4c95d73b620f2b8b0a7dea288d4fc559e0d288ed4ad6cf4eb312f630" +            "sha256": "e35c9bad81b01152ad3e10b85f1abf5866aa87b9d87e03bc30bdb9d37668ccae"          },          "pipfile-spec": 6,          "requires": { @@ -99,13 +99,6 @@              "markers": "python_version >= '3.6'",              "version": "==3.3.1"          }, -        "alabaster": { -            "hashes": [ -                "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359", -                "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02" -            ], -            "version": "==0.7.12" -        },          "arrow": {              "hashes": [                  "sha256:3515630f11a15c61dcb4cdd245883270dd334c83f3e639824e65a4b79cc48543", @@ -142,14 +135,6 @@              "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",              "version": "==20.3.0"          }, -        "babel": { -            "hashes": [ -                "sha256:9d35c22fcc79893c3ecc85ac4a56cde1ecf3f19c540bba0922308a6c06ca6fa5", -                "sha256:da031ab54472314f210b0adcff1588ee5d1d1d0ba4dbd07b94dba82bde791e05" -            ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==2.9.0" -        },          "beautifulsoup4": {              "hashes": [                  "sha256:4c98143716ef1cb40bf7f39a8e3eec8f8b009509e74904ba3a7b315431577e35", @@ -221,7 +206,6 @@                  "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b",                  "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"              ], -            "index": "pypi",              "markers": "sys_platform == 'win32'",              "version": "==0.4.4"          }, @@ -249,14 +233,6 @@              "index": "pypi",              "version": "==1.6.0"          }, -        "docutils": { -            "hashes": [ -                "sha256:a71042bb7207c03d5647f280427f14bfbd1a65c9eb84f4b341d85fafb6bb4bdf", -                "sha256:e2ffeea817964356ba4470efba7c2f42b6b0de0b04e66378507e3e2504bbff4c" -            ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", -            "version": "==0.17" -        },          "emoji": {              "hashes": [                  "sha256:e42da4f8d648f8ef10691bc246f682a1ec6b18373abfd9be10ec0b398823bd11" @@ -345,27 +321,11 @@          },          "idna": {              "hashes": [ -                "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6", -                "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0" +                "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16", +                "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"              ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==2.10" -        }, -        "imagesize": { -            "hashes": [ -                "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1", -                "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1" -            ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==1.2.0" -        }, -        "jinja2": { -            "hashes": [ -                "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419", -                "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6" -            ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", -            "version": "==2.11.3" +            "markers": "python_version >= '3.4'", +            "version": "==3.1"          },          "lxml": {              "hashes": [ @@ -411,69 +371,11 @@          },          "markdownify": {              "hashes": [ -                "sha256:30be8340724e706c9e811c27fe8c1542cf74a15b46827924fff5c54b40dd9b0d", -                "sha256:a69588194fd76634f0139d6801b820fd652dc5eeba9530e90d323dfdc0155252" +                "sha256:31d7c13ac2ada8bfc7535a25fee6622ca720e1b5f2d4a9cbc429d167c21f886d", +                "sha256:7489fd5c601536996a376c4afbcd1dd034db7690af807120681461e82fbc0acc"              ],              "index": "pypi", -            "version": "==0.5.3" -        }, -        "markupsafe": { -            "hashes": [ -                "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473", -                "sha256:09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161", -                "sha256:09c4b7f37d6c648cb13f9230d847adf22f8171b1ccc4d5682398e77f40309235", -                "sha256:1027c282dad077d0bae18be6794e6b6b8c91d58ed8a8d89a89d59693b9131db5", -                "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42", -                "sha256:195d7d2c4fbb0ee8139a6cf67194f3973a6b3042d742ebe0a9ed36d8b6f0c07f", -                "sha256:22c178a091fc6630d0d045bdb5992d2dfe14e3259760e713c490da5323866c39", -                "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff", -                "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b", -                "sha256:2beec1e0de6924ea551859edb9e7679da6e4870d32cb766240ce17e0a0ba2014", -                "sha256:3b8a6499709d29c2e2399569d96719a1b21dcd94410a586a18526b143ec8470f", -                "sha256:43a55c2930bbc139570ac2452adf3d70cdbb3cfe5912c71cdce1c2c6bbd9c5d1", -                "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e", -                "sha256:500d4957e52ddc3351cabf489e79c91c17f6e0899158447047588650b5e69183", -                "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66", -                "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b", -                "sha256:62fe6c95e3ec8a7fad637b7f3d372c15ec1caa01ab47926cfdf7a75b40e0eac1", -                "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15", -                "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1", -                "sha256:6f1e273a344928347c1290119b493a1f0303c52f5a5eae5f16d74f48c15d4a85", -                "sha256:6fffc775d90dcc9aed1b89219549b329a9250d918fd0b8fa8d93d154918422e1", -                "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e", -                "sha256:79855e1c5b8da654cf486b830bd42c06e8780cea587384cf6545b7d9ac013a0b", -                "sha256:7c1699dfe0cf8ff607dbdcc1e9b9af1755371f92a68f706051cc8c37d447c905", -                "sha256:7fed13866cf14bba33e7176717346713881f56d9d2bcebab207f7a036f41b850", -                "sha256:84dee80c15f1b560d55bcfe6d47b27d070b4681c699c572af2e3c7cc90a3b8e0", -                "sha256:88e5fcfb52ee7b911e8bb6d6aa2fd21fbecc674eadd44118a9cc3863f938e735", -                "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d", -                "sha256:98bae9582248d6cf62321dcb52aaf5d9adf0bad3b40582925ef7c7f0ed85fceb", -                "sha256:98c7086708b163d425c67c7a91bad6e466bb99d797aa64f965e9d25c12111a5e", -                "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d", -                "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c", -                "sha256:a6a744282b7718a2a62d2ed9d993cad6f5f585605ad352c11de459f4108df0a1", -                "sha256:acf08ac40292838b3cbbb06cfe9b2cb9ec78fce8baca31ddb87aaac2e2dc3bc2", -                "sha256:ade5e387d2ad0d7ebf59146cc00c8044acbd863725f887353a10df825fc8ae21", -                "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2", -                "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5", -                "sha256:b1dba4527182c95a0db8b6060cc98ac49b9e2f5e64320e2b56e47cb2831978c7", -                "sha256:b2051432115498d3562c084a49bba65d97cf251f5a331c64a12ee7e04dacc51b", -                "sha256:b7d644ddb4dbd407d31ffb699f1d140bc35478da613b441c582aeb7c43838dd8", -                "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6", -                "sha256:bf5aa3cbcfdf57fa2ee9cd1822c862ef23037f5c832ad09cfea57fa846dec193", -                "sha256:c8716a48d94b06bb3b2524c2b77e055fb313aeb4ea620c8dd03a105574ba704f", -                "sha256:caabedc8323f1e93231b52fc32bdcde6db817623d33e100708d9a68e1f53b26b", -                "sha256:cd5df75523866410809ca100dc9681e301e3c27567cf498077e8551b6d20e42f", -                "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2", -                "sha256:d53bc011414228441014aa71dbec320c66468c1030aae3a6e29778a3382d96e5", -                "sha256:d73a845f227b0bfe8a7455ee623525ee656a9e2e749e4742706d80a6065d5e2c", -                "sha256:d9be0ba6c527163cbed5e0857c451fcd092ce83947944d6c14bc95441203f032", -                "sha256:e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7", -                "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be", -                "sha256:feb7b34d6325451ef96bc0e36e1a6c0c1c64bc1fbec4b854f4529e51887b1621" -            ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==1.1.1" +            "version": "==0.6.1"          },          "more-itertools": {              "hashes": [ @@ -533,14 +435,6 @@              "markers": "python_version >= '3.5'",              "version": "==4.0.2"          }, -        "packaging": { -            "hashes": [ -                "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5", -                "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a" -            ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==20.9" -        },          "pamqp": {              "hashes": [                  "sha256:2f81b5c186f668a67f165193925b6bfd83db4363a6222f599517f29ecee60b02", @@ -590,31 +484,6 @@              "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",              "version": "==2.20"          }, -        "pygments": { -            "hashes": [ -                "sha256:2656e1a6edcdabf4275f9a3640db59fd5de107d88e8663c5d4e9a0fa62f77f94", -                "sha256:534ef71d539ae97d4c3a4cf7d6f110f214b0e687e92f9cb9d2a3b0d3101289c8" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==2.8.1" -        }, -        "pyparsing": { -            "hashes": [ -                "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1", -                "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b" -            ], -            "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==2.4.7" -        }, -        "pyreadline": { -            "hashes": [ -                "sha256:4530592fc2e85b25b1a9f79664433da09237c1a270e4d78ea5aa3a2c7229e2d1", -                "sha256:65540c21bfe14405a3a77e4c085ecfce88724743a4ead47c66b84defcf82c32e", -                "sha256:9ce5fa65b8992dfa373bddc5b6e0864ead8f291c94fbfec05fbd5c836162e67b" -            ], -            "markers": "sys_platform == 'win32'", -            "version": "==2.1" -        },          "python-dateutil": {              "hashes": [                  "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", @@ -631,13 +500,6 @@              "index": "pypi",              "version": "==1.0.0"          }, -        "pytz": { -            "hashes": [ -                "sha256:83a4a90894bf38e243cf052c8b58f381bfe9a7a483f6a9cab140bc7f702ac4da", -                "sha256:eb10ce3e7736052ed3623d49975ce333bcd712c7bb19a58b9e2089d4057d0798" -            ], -            "version": "==2021.1" -        },          "pyyaml": {              "hashes": [                  "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf", @@ -728,14 +590,6 @@              "index": "pypi",              "version": "==2021.4.4"          }, -        "requests": { -            "hashes": [ -                "sha256:27973dd4a904a4f13b263a19c866c13b92a39ed1c964655f025f3f8d3d75b804", -                "sha256:c210084e36a42ae6b9219e00e48287def368a26d03a048ddad7bfee44f75871e" -            ], -            "index": "pypi", -            "version": "==2.25.1" -        },          "sentry-sdk": {              "hashes": [                  "sha256:4ae8d1ced6c67f1c8ea51d82a16721c166c489b76876c9f2c202b8a50334b237", @@ -749,16 +603,9 @@                  "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",                  "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"              ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", +            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'",              "version": "==1.15.0"          }, -        "snowballstemmer": { -            "hashes": [ -                "sha256:b51b447bea85f9968c13b650126a888aabd4cb4463fca868ec596826325dedc2", -                "sha256:e997baa4f2e9139951b6f4c631bad912dfd3c792467e2f03d7239464af90e914" -            ], -            "version": "==2.1.0" -        },          "sortedcontainers": {              "hashes": [                  "sha256:37257a32add0a3ee490bb170b599e93095eed89a55da91fa9f48753ea12fd73f", @@ -774,62 +621,6 @@              "markers": "python_version >= '3.0'",              "version": "==2.2.1"          }, -        "sphinx": { -            "hashes": [ -                "sha256:b4c750d546ab6d7e05bdff6ac24db8ae3e8b8253a3569b754e445110a0a12b66", -                "sha256:fc312670b56cb54920d6cc2ced455a22a547910de10b3142276495ced49231cb" -            ], -            "index": "pypi", -            "version": "==2.4.4" -        }, -        "sphinxcontrib-applehelp": { -            "hashes": [ -                "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a", -                "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==1.0.2" -        }, -        "sphinxcontrib-devhelp": { -            "hashes": [ -                "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e", -                "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==1.0.2" -        }, -        "sphinxcontrib-htmlhelp": { -            "hashes": [ -                "sha256:3c0bc24a2c41e340ac37c85ced6dafc879ab485c095b1d65d2461ac2f7cca86f", -                "sha256:e8f5bb7e31b2dbb25b9cc435c8ab7a79787ebf7f906155729338f3156d93659b" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==1.0.3" -        }, -        "sphinxcontrib-jsmath": { -            "hashes": [ -                "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", -                "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==1.0.1" -        }, -        "sphinxcontrib-qthelp": { -            "hashes": [ -                "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72", -                "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==1.0.3" -        }, -        "sphinxcontrib-serializinghtml": { -            "hashes": [ -                "sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc", -                "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a" -            ], -            "markers": "python_version >= '3.5'", -            "version": "==1.1.4" -        },          "statsd": {              "hashes": [                  "sha256:c610fb80347fca0ef62666d241bce64184bd7cc1efe582f9690e045c25535eaa", @@ -1103,11 +894,11 @@          },          "idna": {              "hashes": [ -                "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6", -                "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0" +                "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16", +                "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"              ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", -            "version": "==2.10" +            "markers": "python_version >= '3.4'", +            "version": "==3.1"          },          "mccabe": {              "hashes": [ @@ -1203,7 +994,7 @@                  "sha256:27973dd4a904a4f13b263a19c866c13b92a39ed1c964655f025f3f8d3d75b804",                  "sha256:c210084e36a42ae6b9219e00e48287def368a26d03a048ddad7bfee44f75871e"              ], -            "index": "pypi", +            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",              "version": "==2.25.1"          },          "six": { @@ -1211,7 +1002,7 @@                  "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",                  "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"              ], -            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", +            "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'",              "version": "==1.15.0"          },          "snowballstemmer": { @@ -1226,7 +1017,7 @@                  "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",                  "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"              ], -            "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", +            "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'",              "version": "==0.10.2"          },          "urllib3": { diff --git a/bot/constants.py b/bot/constants.py index 6d14bbb3a..7b2a38079 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -175,13 +175,14 @@ class YAMLGetter(type):              if cls.subsection is not None:                  return _CONFIG_YAML[cls.section][cls.subsection][name]              return _CONFIG_YAML[cls.section][name] -        except KeyError: +        except KeyError as e:              dotted_path = '.'.join(                  (cls.section, cls.subsection, name)                  if cls.subsection is not None else (cls.section, name)              ) -            log.critical(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.") -            raise +            # Only an INFO log since this can be caught through `hasattr` or `getattr`. +            log.info(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.") +            raise AttributeError(repr(name)) from e      def __getitem__(cls, name):          return cls.__getattr__(name) @@ -199,6 +200,7 @@ class Bot(metaclass=YAMLGetter):      prefix: str      sentry_dsn: Optional[str]      token: str +    trace_loggers: Optional[str]  class Redis(metaclass=YAMLGetter): @@ -279,6 +281,8 @@ class Emojis(metaclass=YAMLGetter):      badge_partner: str      badge_staff: str      badge_verified_bot_developer: str +    verified_bot: str +    bot: str      defcon_shutdown: str  # noqa: E704      defcon_unshutdown: str  # noqa: E704 @@ -491,6 +495,7 @@ class Roles(metaclass=YAMLGetter):      domain_leads: int      helpers: int      moderators: int +    mod_team: int      owners: int      project_leads: int diff --git a/bot/converters.py b/bot/converters.py index 67525cd4d..3bf05cfb3 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -15,6 +15,7 @@ from discord.utils import DISCORD_EPOCH, snowflake_time  from bot.api import ResponseCodeError  from bot.constants import URLs +from bot.exts.info.doc import _inventory_parser  from bot.utils.regex import INVITE_RE  from bot.utils.time import parse_duration_string @@ -127,22 +128,20 @@ class ValidFilterListType(Converter):          return list_type -class ValidPythonIdentifier(Converter): +class PackageName(Converter):      """ -    A converter that checks whether the given string is a valid Python identifier. +    A converter that checks whether the given string is a valid package name. -    This is used to have package names that correspond to how you would use the package in your -    code, e.g. `import package`. - -    Raises `BadArgument` if the argument is not a valid Python identifier, and simply passes through -    the given argument otherwise. +    Package names are used for stats and are restricted to the a-z and _ characters.      """ -    @staticmethod -    async def convert(ctx: Context, argument: str) -> str: -        """Checks whether the given string is a valid Python identifier.""" -        if not argument.isidentifier(): -            raise BadArgument(f"`{argument}` is not a valid Python identifier") +    PACKAGE_NAME_RE = re.compile(r"[^a-z0-9_]") + +    @classmethod +    async def convert(cls, ctx: Context, argument: str) -> str: +        """Checks whether the given string is a valid package name.""" +        if cls.PACKAGE_NAME_RE.search(argument): +            raise BadArgument("The provided package name is not valid; please only use the _, 0-9, and a-z characters.")          return argument @@ -178,6 +177,27 @@ class ValidURL(Converter):          return url +class Inventory(Converter): +    """ +    Represents an Intersphinx inventory URL. + +    This converter checks whether intersphinx accepts the given inventory URL, and raises +    `BadArgument` if that is not the case or if the url is unreachable. + +    Otherwise, it returns the url and the fetched inventory dict in a tuple. +    """ + +    @staticmethod +    async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]: +        """Convert url to Intersphinx inventory URL.""" +        await ctx.trigger_typing() +        if (inventory := await _inventory_parser.fetch_inventory(url)) is None: +            raise BadArgument( +                f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts." +            ) +        return url, inventory + +  class Snowflake(IDConverter):      """      Converts to an int if the argument is a valid Discord snowflake. diff --git a/bot/decorators.py b/bot/decorators.py index 0b50cc365..e971a5bd3 100644 --- a/bot/decorators.py +++ b/bot/decorators.py @@ -1,9 +1,9 @@  import asyncio  import functools  import logging +import types  import typing as t  from contextlib import suppress -from functools import wraps  from discord import Member, NotFound  from discord.ext import commands @@ -11,7 +11,8 @@ from discord.ext.commands import Cog, Context  from bot.constants import Channels, DEBUG_MODE, RedirectOutput  from bot.utils import function -from bot.utils.checks import in_whitelist_check +from bot.utils.checks import ContextCheckFailure, in_whitelist_check +from bot.utils.function import command_wraps  log = logging.getLogger(__name__) @@ -44,6 +45,49 @@ def in_whitelist(      return commands.check(predicate) +class NotInBlacklistCheckFailure(ContextCheckFailure): +    """Raised when the 'not_in_blacklist' check fails.""" + + +def not_in_blacklist( +    *, +    channels: t.Container[int] = (), +    categories: t.Container[int] = (), +    roles: t.Container[int] = (), +    override_roles: t.Container[int] = (), +    redirect: t.Optional[int] = Channels.bot_commands, +    fail_silently: bool = False, +) -> t.Callable: +    """ +    Check if a command was not issued in a blacklisted context. + +    The blacklists that can be provided are: + +    - `channels`: a container with channel ids for blacklisted channels +    - `categories`: a container with category ids for blacklisted categories +    - `roles`: a container with role ids for blacklisted roles + +    If the command was invoked in a context that was blacklisted, the member is either +    redirected to the `redirect` channel that was passed (default: #bot-commands) or simply +    told that they're not allowed to use this particular command (if `None` was passed). + +    The blacklist can be overridden through the roles specified in `override_roles`. +    """ +    def predicate(ctx: Context) -> bool: +        """Check if command was issued in a blacklisted context.""" +        not_blacklisted = not in_whitelist_check(ctx, channels, categories, roles, fail_silently=True) +        overridden = in_whitelist_check(ctx, roles=override_roles, fail_silently=True) + +        success = not_blacklisted or overridden + +        if not success and not fail_silently: +            raise NotInBlacklistCheckFailure(redirect) + +        return success + +    return commands.check(predicate) + +  def has_no_roles(*roles: t.Union[str, int]) -> t.Callable:      """      Returns True if the user does not have any of the roles specified. @@ -71,8 +115,8 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N      This decorator must go before (below) the `command` decorator.      """ -    def wrap(func: t.Callable) -> t.Callable: -        @wraps(func) +    def wrap(func: types.FunctionType) -> types.FunctionType: +        @command_wraps(func)          async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:              if ctx.channel.id == destination_channel:                  log.trace(f"Command {ctx.command.name} was invoked in destination_channel, not redirecting") @@ -106,7 +150,6 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N                  with suppress(NotFound):                      await ctx.message.delete()                      log.trace("Redirect output: Deleted invocation message") -          return inner      return wrap @@ -123,8 +166,8 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable:      This decorator must go before (below) the `command` decorator.      """ -    def decorator(func: t.Callable) -> t.Callable: -        @wraps(func) +    def decorator(func: types.FunctionType) -> types.FunctionType: +        @command_wraps(func)          async def wrapper(*args, **kwargs) -> None:              log.trace(f"{func.__name__}: respect role hierarchy decorator called") diff --git a/bot/exts/backend/branding/_cog.py b/bot/exts/backend/branding/_cog.py index 0a4ddcc88..47c379a34 100644 --- a/bot/exts/backend/branding/_cog.py +++ b/bot/exts/backend/branding/_cog.py @@ -3,12 +3,13 @@ import contextlib  import logging  import random  import typing as t -from datetime import datetime, time, timedelta +from datetime import timedelta  from enum import Enum  from operator import attrgetter  import async_timeout  import discord +from arrow import Arrow  from async_rediscache import RedisCache  from discord.ext import commands, tasks @@ -57,6 +58,8 @@ def extract_event_duration(event: Event) -> str:      Extract a human-readable, year-agnostic duration string from `event`.      In the case that `event` is a fallback event, resolves to 'Fallback'. + +    For 1-day events, only the single date is shown, instead of a period.      """      if event.meta.is_fallback:          return "Fallback" @@ -65,6 +68,9 @@ def extract_event_duration(event: Event) -> str:      start_date = event.meta.start_date.strftime(fmt)      end_date = event.meta.end_date.strftime(fmt) +    if start_date == end_date: +        return start_date +      return f"{start_date} - {end_date}" @@ -208,7 +214,7 @@ class Branding(commands.Cog):          if success:              await self.cache_icons.increment(next_icon)  # Push the icon into the next iteration. -            timestamp = datetime.utcnow().timestamp() +            timestamp = Arrow.utcnow().timestamp()              await self.cache_information.set("last_rotation_timestamp", timestamp)          return success @@ -229,8 +235,8 @@ class Branding(commands.Cog):              await self.rotate_icons()              return -        last_rotation = datetime.fromtimestamp(last_rotation_timestamp) -        difference = (datetime.utcnow() - last_rotation) + timedelta(minutes=5) +        last_rotation = Arrow.utcfromtimestamp(last_rotation_timestamp) +        difference = (Arrow.utcnow() - last_rotation) + timedelta(minutes=5)          log.trace(f"Icons last rotated at {last_rotation} (difference: {difference}).") @@ -485,11 +491,11 @@ class Branding(commands.Cog):          await self.daemon_loop()          log.trace("Daemon before: calculating time to sleep before loop begins.") -        now = datetime.utcnow() +        now = Arrow.utcnow()          # The actual midnight moment is offset into the future to prevent issues with imprecise sleep. -        tomorrow = now + timedelta(days=1) -        midnight = datetime.combine(tomorrow, time(minute=1)) +        tomorrow = now.shift(days=1) +        midnight = tomorrow.replace(hour=0, minute=1, second=0, microsecond=0)          sleep_secs = (midnight - now).total_seconds()          log.trace(f"Daemon before: sleeping {sleep_secs} seconds before next-up midnight: {midnight}.") diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 76ab7dfc2..d8de177f5 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,4 +1,3 @@ -import contextlib  import difflib  import logging  import typing as t @@ -12,7 +11,7 @@ from bot.bot import Bot  from bot.constants import Colours, Icons, MODERATION_ROLES  from bot.converters import TagNameConverter  from bot.errors import InvalidInfractedUser, LockedResourceError -from bot.utils.checks import InWhitelistCheckFailure +from bot.utils.checks import ContextCheckFailure  log = logging.getLogger(__name__) @@ -60,7 +59,7 @@ class ErrorHandler(Cog):              log.trace(f"Command {command} had its error already handled locally; ignoring.")              return -        if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): +        if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False):              if await self.try_silence(ctx):                  return              # Try to look for a tag with the command's name @@ -162,9 +161,8 @@ class ErrorHandler(Cog):                  f"and the fallback tag failed validation in TagNameConverter."              )          else: -            with contextlib.suppress(ResponseCodeError): -                if await ctx.invoke(tags_get_command, tag_name=tag_name): -                    return +            if await ctx.invoke(tags_get_command, tag_name=tag_name): +                return          if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):              await self.send_command_suggestion(ctx, ctx.invoked_with) @@ -214,32 +212,30 @@ class ErrorHandler(Cog):          * ArgumentParsingError: send an error message          * Other: send an error message and the help command          """ -        prepared_help_command = self.get_help_command(ctx) -          if isinstance(e, errors.MissingRequiredArgument):              embed = self._get_error_embed("Missing required argument", e.param.name)              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.missing_required_argument")          elif isinstance(e, errors.TooManyArguments):              embed = self._get_error_embed("Too many arguments", str(e))              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.too_many_arguments")          elif isinstance(e, errors.BadArgument):              embed = self._get_error_embed("Bad argument", str(e))              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.bad_argument")          elif isinstance(e, errors.BadUnionArgument):              embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}")              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.bad_union_argument")          elif isinstance(e, errors.ArgumentParsingError):              embed = self._get_error_embed("Argument parsing error", str(e))              await ctx.send(embed=embed) -            prepared_help_command.close() +            self.get_help_command(ctx).close()              self.bot.stats.incr("errors.argument_parsing_error")          else:              embed = self._get_error_embed( @@ -247,7 +243,7 @@ class ErrorHandler(Cog):                  "Something about your input seems off. Check the arguments and try again."              )              await ctx.send(embed=embed) -            await prepared_help_command +            await self.get_help_command(ctx)              self.bot.stats.incr("errors.other_user_input_error")      @staticmethod @@ -274,7 +270,7 @@ class ErrorHandler(Cog):              await ctx.send(                  "Sorry, it looks like I don't have the permissions or roles I need to do that."              ) -        elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)): +        elif isinstance(e, (ContextCheckFailure, errors.NoPrivateMessage)):              ctx.bot.stats.incr("errors.wrong_channel_or_dm_error")              await ctx.send(e) diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py index af8528a68..7555e25a2 100644 --- a/bot/exts/filters/antispam.py +++ b/bot/exts/filters/antispam.py @@ -3,7 +3,7 @@ import logging  from collections.abc import Mapping  from dataclasses import dataclass, field  from datetime import datetime, timedelta -from operator import itemgetter +from operator import attrgetter, itemgetter  from typing import Dict, Iterable, List, Set  from discord import Colour, Member, Message, NotFound, Object, TextChannel @@ -18,6 +18,7 @@ from bot.constants import (  )  from bot.converters import Duration  from bot.exts.moderation.modlog import ModLog +from bot.utils import lock, scheduling  from bot.utils.messages import format_user, send_attachments @@ -114,7 +115,7 @@ class AntiSpam(Cog):          self.message_deletion_queue = dict() -        self.bot.loop.create_task(self.alert_on_validation_error()) +        self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")      @property      def mod_log(self) -> ModLog: @@ -191,7 +192,10 @@ class AntiSpam(Cog):                  if channel.id not in self.message_deletion_queue:                      log.trace(f"Creating queue for channel `{channel.id}`")                      self.message_deletion_queue[message.channel.id] = DeletionContext(channel) -                    self.bot.loop.create_task(self._process_deletion_context(message.channel.id)) +                    scheduling.create_task( +                        self._process_deletion_context(message.channel.id), +                        name=f"AntiSpam._process_deletion_context({message.channel.id})" +                    )                  # Add the relevant of this trigger to the Deletion Context                  await self.message_deletion_queue[message.channel.id].add( @@ -201,16 +205,15 @@ class AntiSpam(Cog):                  )                  for member in members: - -                    # Fire it off as a background task to ensure -                    # that the sleep doesn't block further tasks -                    self.bot.loop.create_task( -                        self.punish(message, member, full_reason) +                    scheduling.create_task( +                        self.punish(message, member, full_reason), +                        name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"                      )                  await self.maybe_delete_messages(channel, relevant_messages)                  break +    @lock.lock_arg("antispam.punish", "member", attrgetter("id"))      async def punish(self, msg: Message, member: Member, reason: str) -> None:          """Punishes the given member for triggering an antispam rule."""          if not any(role.id == self.muted_role.id for role in member.roles): diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py new file mode 100644 index 000000000..06885410b --- /dev/null +++ b/bot/exts/info/code_snippets.py @@ -0,0 +1,265 @@ +import logging +import re +import textwrap +from typing import Any +from urllib.parse import quote_plus + +from aiohttp import ClientResponseError +from discord import Message +from discord.ext.commands import Cog + +from bot.bot import Bot +from bot.constants import Channels +from bot.utils.messages import wait_for_deletion + +log = logging.getLogger(__name__) + +GITHUB_RE = re.compile( +    r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/' +    r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_GIST_RE = re.compile( +    r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*' +    r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?' +    r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)' +) + +GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'} + +GITLAB_RE = re.compile( +    r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)' +    r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)' +) + +BITBUCKET_RE = re.compile( +    r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)' +    r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)' +) + + +class CodeSnippets(Cog): +    """ +    Cog that parses and sends code snippets to Discord. + +    Matches each message against a regex and prints the contents of all matched snippets. +    """ + +    async def _fetch_response(self, url: str, response_format: str, **kwargs) -> Any: +        """Makes http requests using aiohttp.""" +        async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response: +            if response_format == 'text': +                return await response.text() +            elif response_format == 'json': +                return await response.json() + +    def _find_ref(self, path: str, refs: tuple) -> tuple: +        """Loops through all branches and tags to find the required ref.""" +        # Base case: there is no slash in the branch name +        ref, file_path = path.split('/', 1) +        # In case there are slashes in the branch name, we loop through all branches and tags +        for possible_ref in refs: +            if path.startswith(possible_ref['name'] + '/'): +                ref = possible_ref['name'] +                file_path = path[len(ref) + 1:] +                break +        return ref, file_path + +    async def _fetch_github_snippet( +        self, +        repo: str, +        path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitHub repo.""" +        # Search the GitHub API for the specified branch +        branches = await self._fetch_response( +            f'https://api.github.com/repos/{repo}/branches', +            'json', +            headers=GITHUB_HEADERS +        ) +        tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS) +        refs = branches + tags +        ref, file_path = self._find_ref(path, refs) + +        file_contents = await self._fetch_response( +            f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}', +            'text', +            headers=GITHUB_HEADERS, +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    async def _fetch_github_gist_snippet( +        self, +        gist_id: str, +        revision: str, +        file_path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitHub gist.""" +        gist_json = await self._fetch_response( +            f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}', +            'json', +            headers=GITHUB_HEADERS, +        ) + +        # Check each file in the gist for the specified file +        for gist_file in gist_json['files']: +            if file_path == gist_file.lower().replace('.', '-'): +                file_contents = await self._fetch_response( +                    gist_json['files'][gist_file]['raw_url'], +                    'text', +                ) +                return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line) +        return '' + +    async def _fetch_gitlab_snippet( +        self, +        repo: str, +        path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a GitLab repo.""" +        enc_repo = quote_plus(repo) + +        # Searches the GitLab API for the specified branch +        branches = await self._fetch_response( +            f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches', +            'json' +        ) +        tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json') +        refs = branches + tags +        ref, file_path = self._find_ref(path, refs) +        enc_ref = quote_plus(ref) +        enc_file_path = quote_plus(file_path) + +        file_contents = await self._fetch_response( +            f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}', +            'text', +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    async def _fetch_bitbucket_snippet( +        self, +        repo: str, +        ref: str, +        file_path: str, +        start_line: str, +        end_line: str +    ) -> str: +        """Fetches a snippet from a BitBucket repo.""" +        file_contents = await self._fetch_response( +            f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}', +            'text', +        ) +        return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line) + +    def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str: +        """ +        Given the entire file contents and target lines, creates a code block. + +        First, we split the file contents into a list of lines and then keep and join only the required +        ones together. + +        We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent +        markdown injection. + +        Finally, we surround the code with ``` characters. +        """ +        # Parse start_line and end_line into integers +        if end_line is None: +            start_line = end_line = int(start_line) +        else: +            start_line = int(start_line) +            end_line = int(end_line) + +        split_file_contents = file_contents.splitlines() + +        # Make sure that the specified lines are in range +        if start_line > end_line: +            start_line, end_line = end_line, start_line +        if start_line > len(split_file_contents) or end_line < 1: +            return '' +        start_line = max(1, start_line) +        end_line = min(len(split_file_contents), end_line) + +        # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection +        required = '\n'.join(split_file_contents[start_line - 1:end_line]) +        required = textwrap.dedent(required).rstrip().replace('`', '`\u200b') + +        # Extracts the code language and checks whether it's a "valid" language +        language = file_path.split('/')[-1].split('.')[-1] +        trimmed_language = language.replace('-', '').replace('+', '').replace('_', '') +        is_valid_language = trimmed_language.isalnum() +        if not is_valid_language: +            language = '' + +        # Adds a label showing the file path to the snippet +        if start_line == end_line: +            ret = f'`{file_path}` line {start_line}\n' +        else: +            ret = f'`{file_path}` lines {start_line} to {end_line}\n' + +        if len(required) != 0: +            return f'{ret}```{language}\n{required}```' +        # Returns an empty codeblock if the snippet is empty +        return f'{ret}``` ```' + +    def __init__(self, bot: Bot): +        """Initializes the cog's bot.""" +        self.bot = bot + +        self.pattern_handlers = [ +            (GITHUB_RE, self._fetch_github_snippet), +            (GITHUB_GIST_RE, self._fetch_github_gist_snippet), +            (GITLAB_RE, self._fetch_gitlab_snippet), +            (BITBUCKET_RE, self._fetch_bitbucket_snippet) +        ] + +    @Cog.listener() +    async def on_message(self, message: Message) -> None: +        """Checks if the message has a snippet link, removes the embed, then sends the snippet contents.""" +        if not message.author.bot: +            all_snippets = [] + +            for pattern, handler in self.pattern_handlers: +                for match in pattern.finditer(message.content): +                    try: +                        snippet = await handler(**match.groupdict()) +                        all_snippets.append((match.start(), snippet)) +                    except ClientResponseError as error: +                        error_message = error.message  # noqa: B306 +                        log.log( +                            logging.DEBUG if error.status == 404 else logging.ERROR, +                            f'Failed to fetch code snippet from {match[0]!r}: {error.status} ' +                            f'{error_message} for GET {error.request_info.real_url.human_repr()}' +                        ) + +            # Sorts the list of snippets by their match index and joins them into a single message +            message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets))) + +            if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15: +                await message.edit(suppress=True) +                if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands: +                    # Redirects to #bot-commands if the snippet contents are too long +                    await self.bot.wait_until_guild_available() +                    await message.channel.send(('The snippet you tried to send was too long. Please ' +                                                f'see <#{Channels.bot_commands}> for the full snippet.')) +                    bot_commands_channel = self.bot.get_channel(Channels.bot_commands) +                    await wait_for_deletion( +                        await bot_commands_channel.send(message_to_send), +                        (message.author.id,) +                    ) +                else: +                    await wait_for_deletion( +                        await message.channel.send(message_to_send), +                        (message.author.id,) +                    ) + + +def setup(bot: Bot) -> None: +    """Load the CodeSnippets cog.""" +    bot.add_cog(CodeSnippets(bot)) diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py deleted file mode 100644 index 9b5bd6504..000000000 --- a/bot/exts/info/doc.py +++ /dev/null @@ -1,485 +0,0 @@ -import asyncio -import functools -import logging -import re -import textwrap -from contextlib import suppress -from types import SimpleNamespace -from typing import Optional, Tuple - -import discord -from bs4 import BeautifulSoup -from bs4.element import PageElement, Tag -from discord.errors import NotFound -from discord.ext import commands -from markdownify import MarkdownConverter -from requests import ConnectTimeout, ConnectionError, HTTPError -from sphinx.ext import intersphinx -from urllib3.exceptions import ProtocolError - -from bot.bot import Bot -from bot.constants import MODERATION_ROLES, RedirectOutput -from bot.converters import ValidPythonIdentifier, ValidURL -from bot.pagination import LinePaginator -from bot.utils.cache import AsyncCache -from bot.utils.messages import wait_for_deletion - - -log = logging.getLogger(__name__) -logging.getLogger('urllib3').setLevel(logging.WARNING) - -# Since Intersphinx is intended to be used with Sphinx, -# we need to mock its configuration. -SPHINX_MOCK_APP = SimpleNamespace( -    config=SimpleNamespace( -        intersphinx_timeout=3, -        tls_verify=True, -        user_agent="python3:python-discord/bot:1.0.0" -    ) -) - -NO_OVERRIDE_GROUPS = ( -    "2to3fixer", -    "token", -    "label", -    "pdbcommand", -    "term", -) -NO_OVERRIDE_PACKAGES = ( -    "python", -) - -SEARCH_END_TAG_ATTRS = ( -    "data", -    "function", -    "class", -    "exception", -    "seealso", -    "section", -    "rubric", -    "sphinxsidebar", -) -UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") -WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") - -FAILED_REQUEST_RETRY_AMOUNT = 3 -NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay - -symbol_cache = AsyncCache() - - -class DocMarkdownConverter(MarkdownConverter): -    """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" - -    def convert_code(self, el: PageElement, text: str) -> str: -        """Undo `markdownify`s underscore escaping.""" -        return f"`{text}`".replace('\\', '') - -    def convert_pre(self, el: PageElement, text: str) -> str: -        """Wrap any codeblocks in `py` for syntax highlighting.""" -        code = ''.join(el.strings) -        return f"```py\n{code}```" - - -def markdownify(html: str) -> DocMarkdownConverter: -    """Create a DocMarkdownConverter object from the input html.""" -    return DocMarkdownConverter(bullets='•').convert(html) - - -class InventoryURL(commands.Converter): -    """ -    Represents an Intersphinx inventory URL. - -    This converter checks whether intersphinx accepts the given inventory URL, and raises -    `BadArgument` if that is not the case. - -    Otherwise, it simply passes through the given URL. -    """ - -    @staticmethod -    async def convert(ctx: commands.Context, url: str) -> str: -        """Convert url to Intersphinx inventory URL.""" -        try: -            intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url) -        except AttributeError: -            raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.") -        except ConnectionError: -            if url.startswith('https'): -                raise commands.BadArgument( -                    f"Cannot establish a connection to `{url}`. Does it support HTTPS?" -                ) -            raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.") -        except ValueError: -            raise commands.BadArgument( -                f"Failed to read Intersphinx inventory from URL `{url}`. " -                "Are you sure that it's a valid inventory file?" -            ) -        return url - - -class Doc(commands.Cog): -    """A set of commands for querying & displaying documentation.""" - -    def __init__(self, bot: Bot): -        self.base_urls = {} -        self.bot = bot -        self.inventories = {} -        self.renamed_symbols = set() - -        self.bot.loop.create_task(self.init_refresh_inventory()) - -    async def init_refresh_inventory(self) -> None: -        """Refresh documentation inventory on cog initialization.""" -        await self.bot.wait_until_guild_available() -        await self.refresh_inventory() - -    async def update_single( -        self, package_name: str, base_url: str, inventory_url: str -    ) -> None: -        """ -        Rebuild the inventory for a single package. - -        Where: -            * `package_name` is the package name to use, appears in the log -            * `base_url` is the root documentation URL for the specified package, used to build -                absolute paths that link to specific symbols -            * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running -                `intersphinx.fetch_inventory` in an executor on the bot's event loop -        """ -        self.base_urls[package_name] = base_url - -        package = await self._fetch_inventory(inventory_url) -        if not package: -            return None - -        for group, value in package.items(): -            for symbol, (package_name, _version, relative_doc_url, _) in value.items(): -                absolute_doc_url = base_url + relative_doc_url - -                if symbol in self.inventories: -                    group_name = group.split(":")[1] -                    symbol_base_url = self.inventories[symbol].split("/", 3)[2] -                    if ( -                        group_name in NO_OVERRIDE_GROUPS -                        or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES) -                    ): - -                        symbol = f"{group_name}.{symbol}" -                        # If renamed `symbol` already exists, add library name in front to differentiate between them. -                        if symbol in self.renamed_symbols: -                            # Split `package_name` because of packages like Pillow that have spaces in them. -                            symbol = f"{package_name.split()[0]}.{symbol}" - -                        self.inventories[symbol] = absolute_doc_url -                        self.renamed_symbols.add(symbol) -                        continue - -                self.inventories[symbol] = absolute_doc_url - -        log.trace(f"Fetched inventory for {package_name}.") - -    async def refresh_inventory(self) -> None: -        """Refresh internal documentation inventory.""" -        log.debug("Refreshing documentation inventory...") - -        # Clear the old base URLS and inventories to ensure -        # that we start from a fresh local dataset. -        # Also, reset the cache used for fetching documentation. -        self.base_urls.clear() -        self.inventories.clear() -        self.renamed_symbols.clear() -        symbol_cache.clear() - -        # Run all coroutines concurrently - since each of them performs a HTTP -        # request, this speeds up fetching the inventory data heavily. -        coros = [ -            self.update_single( -                package["package"], package["base_url"], package["inventory_url"] -            ) for package in await self.bot.api_client.get('bot/documentation-links') -        ] -        await asyncio.gather(*coros) - -    async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]: -        """ -        Given a Python symbol, return its signature and description. - -        The first tuple element is the signature of the given symbol as a markup-free string, and -        the second tuple element is the description of the given symbol with HTML markup included. - -        If the given symbol is a module, returns a tuple `(None, str)` -        else if the symbol could not be found, returns `None`. -        """ -        url = self.inventories.get(symbol) -        if url is None: -            return None - -        async with self.bot.http_session.get(url) as response: -            html = await response.text(encoding='utf-8') - -        # Find the signature header and parse the relevant parts. -        symbol_id = url.split('#')[-1] -        soup = BeautifulSoup(html, 'lxml') -        symbol_heading = soup.find(id=symbol_id) -        search_html = str(soup) - -        if symbol_heading is None: -            return None - -        if symbol_id == f"module-{symbol}": -            # Get page content from the module headerlink to the -            # first tag that has its class in `SEARCH_END_TAG_ATTRS` -            start_tag = symbol_heading.find("a", attrs={"class": "headerlink"}) -            if start_tag is None: -                return [], "" - -            end_tag = start_tag.find_next(self._match_end_tag) -            if end_tag is None: -                return [], "" - -            description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent)) -            description_end_index = search_html.find(str(end_tag)) -            description = search_html[description_start_index:description_end_index] -            signatures = None - -        else: -            signatures = [] -            description = str(symbol_heading.find_next_sibling("dd")) -            description_pos = search_html.find(description) -            # Get text of up to 3 signatures, remove unwanted symbols -            for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2): -                signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) -                if signature and search_html.find(str(element)) < description_pos: -                    signatures.append(signature) - -        return signatures, description.replace('¶', '') - -    @symbol_cache(arg_offset=1) -    async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]: -        """ -        Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents. - -        If the symbol is known, an Embed with documentation about it is returned. -        """ -        scraped_html = await self.get_symbol_html(symbol) -        if scraped_html is None: -            return None - -        signatures = scraped_html[0] -        permalink = self.inventories[symbol] -        description = markdownify(scraped_html[1]) - -        # Truncate the description of the embed to the last occurrence -        # of a double newline (interpreted as a paragraph) before index 1000. -        if len(description) > 1000: -            shortened = description[:1000] -            description_cutoff = shortened.rfind('\n\n', 100) -            if description_cutoff == -1: -                # Search the shortened version for cutoff points in decreasing desirability, -                # cutoff at 1000 if none are found. -                for string in (". ", ", ", ",", " "): -                    description_cutoff = shortened.rfind(string) -                    if description_cutoff != -1: -                        break -                else: -                    description_cutoff = 1000 -            description = description[:description_cutoff] - -            # If there is an incomplete code block, cut it out -            if description.count("```") % 2: -                codeblock_start = description.rfind('```py') -                description = description[:codeblock_start].rstrip() -            description += f"... [read more]({permalink})" - -        description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description) -        if signatures is None: -            # If symbol is a module, don't show signature. -            embed_description = description - -        elif not signatures: -            # It's some "meta-page", for example: -            # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views -            embed_description = "This appears to be a generic page not tied to a specific symbol." - -        else: -            embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures) -            embed_description += f"\n{description}" - -        embed = discord.Embed( -            title=f'`{symbol}`', -            url=permalink, -            description=embed_description -        ) -        # Show all symbols with the same name that were renamed in the footer. -        embed.set_footer( -            text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}")) -        ) -        return embed - -    @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True) -    async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: -        """Lookup documentation for Python symbols.""" -        await self.get_command(ctx, symbol) - -    @docs_group.command(name='get', aliases=('g',)) -    async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None: -        """ -        Return a documentation embed for a given symbol. - -        If no symbol is given, return a list of all available inventories. - -        Examples: -            !docs -            !docs aiohttp -            !docs aiohttp.ClientSession -            !docs get aiohttp.ClientSession -        """ -        if symbol is None: -            inventory_embed = discord.Embed( -                title=f"All inventories (`{len(self.base_urls)}` total)", -                colour=discord.Colour.blue() -            ) - -            lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) -            if self.base_urls: -                await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) - -            else: -                inventory_embed.description = "Hmmm, seems like there's nothing here yet." -                await ctx.send(embed=inventory_embed) - -        else: -            # Fetching documentation for a symbol (at least for the first time, since -            # caching is used) takes quite some time, so let's send typing to indicate -            # that we got the command, but are still working on it. -            async with ctx.typing(): -                doc_embed = await self.get_symbol_embed(symbol) - -            if doc_embed is None: -                error_embed = discord.Embed( -                    description=f"Sorry, I could not find any documentation for `{symbol}`.", -                    colour=discord.Colour.red() -                ) -                error_message = await ctx.send(embed=error_embed) -                with suppress(NotFound): -                    await error_message.delete(delay=NOT_FOUND_DELETE_DELAY) -                    await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY) -            else: -                msg = await ctx.send(embed=doc_embed) -                await wait_for_deletion(msg, (ctx.author.id,)) - -    @docs_group.command(name='set', aliases=('s',)) -    @commands.has_any_role(*MODERATION_ROLES) -    async def set_command( -        self, ctx: commands.Context, package_name: ValidPythonIdentifier, -        base_url: ValidURL, inventory_url: InventoryURL -    ) -> None: -        """ -        Adds a new documentation metadata object to the site's database. - -        The database will update the object, should an existing item with the specified `package_name` already exist. - -        Example: -            !docs set \ -                    python \ -                    https://docs.python.org/3/ \ -                    https://docs.python.org/3/objects.inv -        """ -        body = { -            'package': package_name, -            'base_url': base_url, -            'inventory_url': inventory_url -        } -        await self.bot.api_client.post('bot/documentation-links', json=body) - -        log.info( -            f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" -            f"Package name: {package_name}\n" -            f"Base url: {base_url}\n" -            f"Inventory URL: {inventory_url}" -        ) - -        # Rebuilding the inventory can take some time, so lets send out a -        # typing event to show that the Bot is still working. -        async with ctx.typing(): -            await self.refresh_inventory() -        await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.") - -    @docs_group.command(name='delete', aliases=('remove', 'rm', 'd')) -    @commands.has_any_role(*MODERATION_ROLES) -    async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None: -        """ -        Removes the specified package from the database. - -        Examples: -            !docs delete aiohttp -        """ -        await self.bot.api_client.delete(f'bot/documentation-links/{package_name}') - -        async with ctx.typing(): -            # Rebuild the inventory to ensure that everything -            # that was from this package is properly deleted. -            await self.refresh_inventory() -        await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.") - -    @docs_group.command(name="refresh", aliases=("rfsh", "r")) -    @commands.has_any_role(*MODERATION_ROLES) -    async def refresh_command(self, ctx: commands.Context) -> None: -        """Refresh inventories and send differences to channel.""" -        old_inventories = set(self.base_urls) -        with ctx.typing(): -            await self.refresh_inventory() -        # Get differences of added and removed inventories -        added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories) -        if added: -            added = f"+ {added}" - -        removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls) -        if removed: -            removed = f"- {removed}" - -        embed = discord.Embed( -            title="Inventories refreshed", -            description=f"```diff\n{added}\n{removed}```" if added or removed else "" -        ) -        await ctx.send(embed=embed) - -    async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]: -        """Get and return inventory from `inventory_url`. If fetching fails, return None.""" -        fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url) -        for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1): -            try: -                package = await self.bot.loop.run_in_executor(None, fetch_func) -            except ConnectTimeout: -                log.error( -                    f"Fetching of inventory {inventory_url} timed out," -                    f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" -                ) -            except ProtocolError: -                log.error( -                    f"Connection lost while fetching inventory {inventory_url}," -                    f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})" -                ) -            except HTTPError as e: -                log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.") -                return None -            except ConnectionError: -                log.error(f"Couldn't establish connection to inventory {inventory_url}.") -                return None -            else: -                return package -        log.error(f"Fetching of inventory {inventory_url} failed.") -        return None - -    @staticmethod -    def _match_end_tag(tag: Tag) -> bool: -        """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table.""" -        for attr in SEARCH_END_TAG_ATTRS: -            if attr in tag.get("class", ()): -                return True - -        return tag.name == "table" - - -def setup(bot: Bot) -> None: -    """Load the Doc cog.""" -    bot.add_cog(Doc(bot)) diff --git a/bot/exts/info/doc/__init__.py b/bot/exts/info/doc/__init__.py new file mode 100644 index 000000000..38a8975c0 --- /dev/null +++ b/bot/exts/info/doc/__init__.py @@ -0,0 +1,16 @@ +from bot.bot import Bot +from ._redis_cache import DocRedisCache + +MAX_SIGNATURE_AMOUNT = 3 +PRIORITY_PACKAGES = ( +    "python", +) +NAMESPACE = "doc" + +doc_cache = DocRedisCache(namespace=NAMESPACE) + + +def setup(bot: Bot) -> None: +    """Load the Doc cog.""" +    from ._cog import DocCog +    bot.add_cog(DocCog(bot)) diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py new file mode 100644 index 000000000..369bb462c --- /dev/null +++ b/bot/exts/info/doc/_batch_parser.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import asyncio +import collections +import logging +from collections import defaultdict +from contextlib import suppress +from operator import attrgetter +from typing import Deque, Dict, List, NamedTuple, Optional, Union + +import discord +from bs4 import BeautifulSoup + +import bot +from bot.constants import Channels +from bot.utils import scheduling +from . import _cog, doc_cache +from ._parsing import get_symbol_markdown + +log = logging.getLogger(__name__) + + +class StaleInventoryNotifier: +    """Handle sending notifications about stale inventories through `DocItem`s to dev log.""" + +    def __init__(self): +        self._init_task = bot.instance.loop.create_task( +            self._init_channel(), +            name="StaleInventoryNotifier channel init" +        ) +        self._warned_urls = set() + +    async def _init_channel(self) -> None: +        """Wait for guild and get channel.""" +        await bot.instance.wait_until_guild_available() +        self._dev_log = bot.instance.get_channel(Channels.dev_log) + +    async def send_warning(self, doc_item: _cog.DocItem) -> None: +        """Send a warning to dev log if one wasn't already sent for `item`'s url.""" +        if doc_item.url not in self._warned_urls: +            self._warned_urls.add(doc_item.url) +            await self._init_task +            embed = discord.Embed( +                description=f"Doc item `{doc_item.symbol_id=}` present in loaded documentation inventories " +                            f"not found on [site]({doc_item.url}), inventories may need to be refreshed." +            ) +            await self._dev_log.send(embed=embed) + + +class QueueItem(NamedTuple): +    """Contains a `DocItem` and the `BeautifulSoup` object needed to parse it.""" + +    doc_item: _cog.DocItem +    soup: BeautifulSoup + +    def __eq__(self, other: Union[QueueItem, _cog.DocItem]): +        if isinstance(other, _cog.DocItem): +            return self.doc_item == other +        return NamedTuple.__eq__(self, other) + + +class ParseResultFuture(asyncio.Future): +    """ +    Future with metadata for the parser class. + +    `user_requested` is set by the parser when a Future is requested by an user and moved to the front, +    allowing the futures to only be waited for when clearing if they were user requested. +    """ + +    def __init__(self): +        super().__init__() +        self.user_requested = False + + +class BatchParser: +    """ +    Get the Markdown of all symbols on a page and send them to redis when a symbol is requested. + +    DocItems are added through the `add_item` method which adds them to the `_page_doc_items` dict. +    `get_markdown` is used to fetch the Markdown; when this is used for the first time on a page, +    all of the symbols are queued to be parsed to avoid multiple web requests to the same page. +    """ + +    def __init__(self): +        self._queue: Deque[QueueItem] = collections.deque() +        self._page_doc_items: Dict[str, List[_cog.DocItem]] = defaultdict(list) +        self._item_futures: Dict[_cog.DocItem, ParseResultFuture] = defaultdict(ParseResultFuture) +        self._parse_task = None + +        self.stale_inventory_notifier = StaleInventoryNotifier() + +    async def get_markdown(self, doc_item: _cog.DocItem) -> Optional[str]: +        """ +        Get the result Markdown of `doc_item`. + +        If no symbols were fetched from `doc_item`s page before, +        the HTML has to be fetched and then all items from the page are put into the parse queue. + +        Not safe to run while `self.clear` is running. +        """ +        if doc_item not in self._item_futures and doc_item not in self._queue: +            self._item_futures[doc_item].user_requested = True + +            async with bot.instance.http_session.get(doc_item.url) as response: +                soup = await bot.instance.loop.run_in_executor( +                    None, +                    BeautifulSoup, +                    await response.text(encoding="utf8"), +                    "lxml", +                ) + +            self._queue.extendleft(QueueItem(item, soup) for item in self._page_doc_items[doc_item.url]) +            log.debug(f"Added items from {doc_item.url} to the parse queue.") + +            if self._parse_task is None: +                self._parse_task = scheduling.create_task(self._parse_queue(), name="Queue parse") +        else: +            self._item_futures[doc_item].user_requested = True +        with suppress(ValueError): +            # If the item is not in the queue then the item is already parsed or is being parsed +            self._move_to_front(doc_item) +        return await self._item_futures[doc_item] + +    async def _parse_queue(self) -> None: +        """ +        Parse all items from the queue, setting their result Markdown on the futures and sending them to redis. + +        The coroutine will run as long as the queue is not empty, resetting `self._parse_task` to None when finished. +        """ +        log.trace("Starting queue parsing.") +        try: +            while self._queue: +                item, soup = self._queue.pop() +                markdown = None + +                if (future := self._item_futures[item]).done(): +                    # Some items are present in the inventories multiple times under different symbol names, +                    # if we already parsed an equal item, we can just skip it. +                    continue + +                try: +                    markdown = await bot.instance.loop.run_in_executor(None, get_symbol_markdown, soup, item) +                    if markdown is not None: +                        await doc_cache.set(item, markdown) +                    else: +                        # Don't wait for this coro as the parsing doesn't depend on anything it does. +                        scheduling.create_task( +                            self.stale_inventory_notifier.send_warning(item), name="Stale inventory warning" +                        ) +                except Exception: +                    log.exception(f"Unexpected error when handling {item}") +                future.set_result(markdown) +                del self._item_futures[item] +                await asyncio.sleep(0.1) +        finally: +            self._parse_task = None +            log.trace("Finished parsing queue.") + +    def _move_to_front(self, item: Union[QueueItem, _cog.DocItem]) -> None: +        """Move `item` to the front of the parse queue.""" +        # The parse queue stores soups along with the doc symbols in QueueItem objects, +        # in case we're moving a DocItem we have to get the associated QueueItem first and then move it. +        item_index = self._queue.index(item) +        queue_item = self._queue[item_index] +        del self._queue[item_index] + +        self._queue.append(queue_item) +        log.trace(f"Moved {item} to the front of the queue.") + +    def add_item(self, doc_item: _cog.DocItem) -> None: +        """Map a DocItem to its page so that the symbol will be parsed once the page is requested.""" +        self._page_doc_items[doc_item.url].append(doc_item) + +    async def clear(self) -> None: +        """ +        Clear all internal symbol data. + +        Wait for all user-requested symbols to be parsed before clearing the parser. +        """ +        for future in filter(attrgetter("user_requested"), self._item_futures.values()): +            await future +        if self._parse_task is not None: +            self._parse_task.cancel() +        self._queue.clear() +        self._page_doc_items.clear() +        self._item_futures.clear() diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py new file mode 100644 index 000000000..2a8016fb8 --- /dev/null +++ b/bot/exts/info/doc/_cog.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +import asyncio +import logging +import sys +import textwrap +from collections import defaultdict +from contextlib import suppress +from types import SimpleNamespace +from typing import Dict, NamedTuple, Optional, Tuple, Union + +import aiohttp +import discord +from discord.ext import commands + +from bot.bot import Bot +from bot.constants import MODERATION_ROLES, RedirectOutput +from bot.converters import Inventory, PackageName, ValidURL, allowed_strings +from bot.pagination import LinePaginator +from bot.utils.lock import SharedEvent, lock +from bot.utils.messages import send_denial, wait_for_deletion +from bot.utils.scheduling import Scheduler +from . import NAMESPACE, PRIORITY_PACKAGES, _batch_parser, doc_cache +from ._inventory_parser import InventoryDict, fetch_inventory + +log = logging.getLogger(__name__) + +# symbols with a group contained here will get the group prefixed on duplicates +FORCE_PREFIX_GROUPS = ( +    "2to3fixer", +    "token", +    "label", +    "pdbcommand", +    "term", +) +NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay +# Delay to wait before trying to reach a rescheduled inventory again, in minutes +FETCH_RESCHEDULE_DELAY = SimpleNamespace(first=2, repeated=5) + +COMMAND_LOCK_SINGLETON = "inventory refresh" + + +class DocItem(NamedTuple): +    """Holds inventory symbol information.""" + +    package: str  # Name of the package name the symbol is from +    group: str  # Interpshinx "role" of the symbol, for example `label` or `method` +    base_url: str  # Absolute path to to which the relative path resolves, same for all items with the same package +    relative_url_path: str  # Relative path to the page where the symbol is located +    symbol_id: str  # Fragment id used to locate the symbol on the page + +    @property +    def url(self) -> str: +        """Return the absolute url to the symbol.""" +        return self.base_url + self.relative_url_path + + +class DocCog(commands.Cog): +    """A set of commands for querying & displaying documentation.""" + +    def __init__(self, bot: Bot): +        # Contains URLs to documentation home pages. +        # Used to calculate inventory diffs on refreshes and to display all currently stored inventories. +        self.base_urls = {} +        self.bot = bot +        self.doc_symbols: Dict[str, DocItem] = {}  # Maps symbol names to objects containing their metadata. +        self.item_fetcher = _batch_parser.BatchParser() +        # Maps a conflicting symbol name to a list of the new, disambiguated names created from conflicts with the name. +        self.renamed_symbols = defaultdict(list) + +        self.inventory_scheduler = Scheduler(self.__class__.__name__) + +        self.refresh_event = asyncio.Event() +        self.refresh_event.set() +        self.symbol_get_event = SharedEvent() + +        self.init_refresh_task = self.bot.loop.create_task( +            self.init_refresh_inventory(), +            name="Doc inventory init" +        ) + +    @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True) +    async def init_refresh_inventory(self) -> None: +        """Refresh documentation inventory on cog initialization.""" +        await self.bot.wait_until_guild_available() +        await self.refresh_inventories() + +    def update_single(self, package_name: str, base_url: str, inventory: InventoryDict) -> None: +        """ +        Build the inventory for a single package. + +        Where: +            * `package_name` is the package name to use in logs and when qualifying symbols +            * `base_url` is the root documentation URL for the specified package, used to build +                absolute paths that link to specific symbols +            * `package` is the content of a intersphinx inventory. +        """ +        self.base_urls[package_name] = base_url + +        for group, items in inventory.items(): +            for symbol_name, relative_doc_url in items: + +                # e.g. get 'class' from 'py:class' +                group_name = group.split(":")[1] +                symbol_name = self.ensure_unique_symbol_name( +                    package_name, +                    group_name, +                    symbol_name, +                ) + +                relative_url_path, _, symbol_id = relative_doc_url.partition("#") +                # Intern fields that have shared content so we're not storing unique strings for every object +                doc_item = DocItem( +                    package_name, +                    sys.intern(group_name), +                    base_url, +                    sys.intern(relative_url_path), +                    symbol_id, +                ) +                self.doc_symbols[symbol_name] = doc_item +                self.item_fetcher.add_item(doc_item) + +        log.trace(f"Fetched inventory for {package_name}.") + +    async def update_or_reschedule_inventory( +        self, +        api_package_name: str, +        base_url: str, +        inventory_url: str, +    ) -> None: +        """ +        Update the cog's inventories, or reschedule this method to execute again if the remote inventory is unreachable. + +        The first attempt is rescheduled to execute in `FETCH_RESCHEDULE_DELAY.first` minutes, the subsequent attempts +        in `FETCH_RESCHEDULE_DELAY.repeated` minutes. +        """ +        package = await fetch_inventory(inventory_url) + +        if not package: +            if api_package_name in self.inventory_scheduler: +                self.inventory_scheduler.cancel(api_package_name) +                delay = FETCH_RESCHEDULE_DELAY.repeated +            else: +                delay = FETCH_RESCHEDULE_DELAY.first +            log.info(f"Failed to fetch inventory; attempting again in {delay} minutes.") +            self.inventory_scheduler.schedule_later( +                delay*60, +                api_package_name, +                self.update_or_reschedule_inventory(api_package_name, base_url, inventory_url), +            ) +        else: +            self.update_single(api_package_name, base_url, package) + +    def ensure_unique_symbol_name(self, package_name: str, group_name: str, symbol_name: str) -> str: +        """ +        Ensure `symbol_name` doesn't overwrite an another symbol in `doc_symbols`. + +        For conflicts, rename either the current symbol or the existing symbol with which it conflicts. +        Store the new name in `renamed_symbols` and return the name to use for the symbol. + +        If the existing symbol was renamed or there was no conflict, the returned name is equivalent to `symbol_name`. +        """ +        if (item := self.doc_symbols.get(symbol_name)) is None: +            return symbol_name  # There's no conflict so it's fine to simply use the given symbol name. + +        def rename(prefix: str, *, rename_extant: bool = False) -> str: +            new_name = f"{prefix}.{symbol_name}" +            if new_name in self.doc_symbols: +                # If there's still a conflict, qualify the name further. +                if rename_extant: +                    new_name = f"{item.package}.{item.group}.{symbol_name}" +                else: +                    new_name = f"{package_name}.{group_name}.{symbol_name}" + +            self.renamed_symbols[symbol_name].append(new_name) + +            if rename_extant: +                # Instead of renaming the current symbol, rename the symbol with which it conflicts. +                self.doc_symbols[new_name] = self.doc_symbols[symbol_name] +                return symbol_name +            else: +                return new_name + +        # Certain groups are added as prefixes to disambiguate the symbols. +        if group_name in FORCE_PREFIX_GROUPS: +            return rename(group_name) + +        # The existing symbol with which the current symbol conflicts should have a group prefix. +        # It currently doesn't have the group prefix because it's only added once there's a conflict. +        elif item.group in FORCE_PREFIX_GROUPS: +            return rename(item.group, rename_extant=True) + +        elif package_name in PRIORITY_PACKAGES: +            return rename(item.package, rename_extant=True) + +        # If we can't specially handle the symbol through its group or package, +        # fall back to prepending its package name to the front. +        else: +            return rename(package_name) + +    async def refresh_inventories(self) -> None: +        """Refresh internal documentation inventories.""" +        self.refresh_event.clear() +        await self.symbol_get_event.wait() +        log.debug("Refreshing documentation inventory...") +        self.inventory_scheduler.cancel_all() + +        self.base_urls.clear() +        self.doc_symbols.clear() +        self.renamed_symbols.clear() +        await self.item_fetcher.clear() + +        coros = [ +            self.update_or_reschedule_inventory( +                package["package"], package["base_url"], package["inventory_url"] +            ) for package in await self.bot.api_client.get("bot/documentation-links") +        ] +        await asyncio.gather(*coros) +        log.debug("Finished inventory refresh.") +        self.refresh_event.set() + +    def get_symbol_item(self, symbol_name: str) -> Tuple[str, Optional[DocItem]]: +        """ +        Get the `DocItem` and the symbol name used to fetch it from the `doc_symbols` dict. + +        If the doc item is not found directly from the passed in name and the name contains a space, +        the first word of the name will be attempted to be used to get the item. +        """ +        doc_item = self.doc_symbols.get(symbol_name) +        if doc_item is None and " " in symbol_name: +            symbol_name = symbol_name.split(" ", maxsplit=1)[0] +            doc_item = self.doc_symbols.get(symbol_name) + +        return symbol_name, doc_item + +    async def get_symbol_markdown(self, doc_item: DocItem) -> str: +        """ +        Get the Markdown from the symbol `doc_item` refers to. + +        First a redis lookup is attempted, if that fails the `item_fetcher` +        is used to fetch the page and parse the HTML from it into Markdown. +        """ +        markdown = await doc_cache.get(doc_item) + +        if markdown is None: +            log.debug(f"Redis cache miss with {doc_item}.") +            try: +                markdown = await self.item_fetcher.get_markdown(doc_item) + +            except aiohttp.ClientError as e: +                log.warning(f"A network error has occurred when requesting parsing of {doc_item}.", exc_info=e) +                return "Unable to parse the requested symbol due to a network error." + +            except Exception: +                log.exception(f"An unexpected error has occurred when requesting parsing of {doc_item}.") +                return "Unable to parse the requested symbol due to an error." + +            if markdown is None: +                return "Unable to parse the requested symbol." +        return markdown + +    async def create_symbol_embed(self, symbol_name: str) -> Optional[discord.Embed]: +        """ +        Attempt to scrape and fetch the data for the given `symbol_name`, and build an embed from its contents. + +        If the symbol is known, an Embed with documentation about it is returned. + +        First check the DocRedisCache before querying the cog's `BatchParser`. +        """ +        log.trace(f"Building embed for symbol `{symbol_name}`") +        if not self.refresh_event.is_set(): +            log.debug("Waiting for inventories to be refreshed before processing item.") +            await self.refresh_event.wait() +        # Ensure a refresh can't run in case of a context switch until the with block is exited +        with self.symbol_get_event: +            symbol_name, doc_item = self.get_symbol_item(symbol_name) +            if doc_item is None: +                log.debug("Symbol does not exist.") +                return None + +            self.bot.stats.incr(f"doc_fetches.{doc_item.package}") + +            # Show all symbols with the same name that were renamed in the footer, +            # with a max of 200 chars. +            if symbol_name in self.renamed_symbols: +                renamed_symbols = ", ".join(self.renamed_symbols[symbol_name]) +                footer_text = textwrap.shorten("Similar names: " + renamed_symbols, 200, placeholder=" ...") +            else: +                footer_text = "" + +            embed = discord.Embed( +                title=discord.utils.escape_markdown(symbol_name), +                url=f"{doc_item.url}#{doc_item.symbol_id}", +                description=await self.get_symbol_markdown(doc_item) +            ) +            embed.set_footer(text=footer_text) +            return embed + +    @commands.group(name="docs", aliases=("doc", "d"), invoke_without_command=True) +    async def docs_group(self, ctx: commands.Context, *, symbol_name: Optional[str]) -> None: +        """Look up documentation for Python symbols.""" +        await self.get_command(ctx, symbol_name=symbol_name) + +    @docs_group.command(name="getdoc", aliases=("g",)) +    async def get_command(self, ctx: commands.Context, *, symbol_name: Optional[str]) -> None: +        """ +        Return a documentation embed for a given symbol. + +        If no symbol is given, return a list of all available inventories. + +        Examples: +            !docs +            !docs aiohttp +            !docs aiohttp.ClientSession +            !docs getdoc aiohttp.ClientSession +        """ +        if not symbol_name: +            inventory_embed = discord.Embed( +                title=f"All inventories (`{len(self.base_urls)}` total)", +                colour=discord.Colour.blue() +            ) + +            lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items()) +            if self.base_urls: +                await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False) + +            else: +                inventory_embed.description = "Hmmm, seems like there's nothing here yet." +                await ctx.send(embed=inventory_embed) + +        else: +            symbol = symbol_name.strip("`") +            async with ctx.typing(): +                doc_embed = await self.create_symbol_embed(symbol) + +            if doc_embed is None: +                error_message = await send_denial(ctx, "No documentation found for the requested symbol.") +                await wait_for_deletion(error_message, (ctx.author.id,), timeout=NOT_FOUND_DELETE_DELAY) +                with suppress(discord.NotFound): +                    await ctx.message.delete() +                with suppress(discord.NotFound): +                    await error_message.delete() +            else: +                msg = await ctx.send(embed=doc_embed) +                await wait_for_deletion(msg, (ctx.author.id,)) + +    @docs_group.command(name="setdoc", aliases=("s",)) +    @commands.has_any_role(*MODERATION_ROLES) +    @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True) +    async def set_command( +        self, +        ctx: commands.Context, +        package_name: PackageName, +        base_url: ValidURL, +        inventory: Inventory, +    ) -> None: +        """ +        Adds a new documentation metadata object to the site's database. + +        The database will update the object, should an existing item with the specified `package_name` already exist. + +        Example: +            !docs setdoc \ +                    python \ +                    https://docs.python.org/3/ \ +                    https://docs.python.org/3/objects.inv +        """ +        if not base_url.endswith("/"): +            raise commands.BadArgument("The base url must end with a slash.") +        inventory_url, inventory_dict = inventory +        body = { +            "package": package_name, +            "base_url": base_url, +            "inventory_url": inventory_url +        } +        await self.bot.api_client.post("bot/documentation-links", json=body) + +        log.info( +            f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n" +            + "\n".join(f"{key}: {value}" for key, value in body.items()) +        ) + +        self.update_single(package_name, base_url, inventory_dict) +        await ctx.send(f"Added the package `{package_name}` to the database and updated the inventories.") + +    @docs_group.command(name="deletedoc", aliases=("removedoc", "rm", "d")) +    @commands.has_any_role(*MODERATION_ROLES) +    @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True) +    async def delete_command(self, ctx: commands.Context, package_name: PackageName) -> None: +        """ +        Removes the specified package from the database. + +        Example: +            !docs deletedoc aiohttp +        """ +        await self.bot.api_client.delete(f"bot/documentation-links/{package_name}") + +        async with ctx.typing(): +            await self.refresh_inventories() +            await doc_cache.delete(package_name) +        await ctx.send(f"Successfully deleted `{package_name}` and refreshed the inventories.") + +    @docs_group.command(name="refreshdoc", aliases=("rfsh", "r")) +    @commands.has_any_role(*MODERATION_ROLES) +    @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True) +    async def refresh_command(self, ctx: commands.Context) -> None: +        """Refresh inventories and show the difference.""" +        old_inventories = set(self.base_urls) +        with ctx.typing(): +            await self.refresh_inventories() +        new_inventories = set(self.base_urls) + +        if added := ", ".join(new_inventories - old_inventories): +            added = "+ " + added + +        if removed := ", ".join(old_inventories - new_inventories): +            removed = "- " + removed + +        embed = discord.Embed( +            title="Inventories refreshed", +            description=f"```diff\n{added}\n{removed}```" if added or removed else "" +        ) +        await ctx.send(embed=embed) + +    @docs_group.command(name="cleardoccache", aliases=("deletedoccache",)) +    @commands.has_any_role(*MODERATION_ROLES) +    async def clear_cache_command( +        self, +        ctx: commands.Context, +        package_name: Union[PackageName, allowed_strings("*")]  # noqa: F722 +    ) -> None: +        """Clear the persistent redis cache for `package`.""" +        if await doc_cache.delete(package_name): +            await ctx.send(f"Successfully cleared the cache for `{package_name}`.") +        else: +            await ctx.send("No keys matching the package found.") + +    def cog_unload(self) -> None: +        """Clear scheduled inventories, queued symbols and cleanup task on cog unload.""" +        self.inventory_scheduler.cancel_all() +        self.init_refresh_task.cancel() +        asyncio.create_task(self.item_fetcher.clear(), name="DocCog.item_fetcher unload clear") diff --git a/bot/exts/info/doc/_html.py b/bot/exts/info/doc/_html.py new file mode 100644 index 000000000..94efd81b7 --- /dev/null +++ b/bot/exts/info/doc/_html.py @@ -0,0 +1,136 @@ +import logging +import re +from functools import partial +from typing import Callable, Container, Iterable, List, Union + +from bs4 import BeautifulSoup +from bs4.element import NavigableString, PageElement, SoupStrainer, Tag + +from . import MAX_SIGNATURE_AMOUNT + +log = logging.getLogger(__name__) + +_UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶") +_SEARCH_END_TAG_ATTRS = ( +    "data", +    "function", +    "class", +    "exception", +    "seealso", +    "section", +    "rubric", +    "sphinxsidebar", +) + + +class Strainer(SoupStrainer): +    """Subclass of SoupStrainer to allow matching of both `Tag`s and `NavigableString`s.""" + +    def __init__(self, *, include_strings: bool, **kwargs): +        self.include_strings = include_strings +        passed_text = kwargs.pop("text", None) +        if passed_text is not None: +            log.warning("`text` is not a supported kwarg in the custom strainer.") +        super().__init__(**kwargs) + +    Markup = Union[PageElement, List["Markup"]] + +    def search(self, markup: Markup) -> Union[PageElement, str]: +        """Extend default SoupStrainer behaviour to allow matching both `Tag`s` and `NavigableString`s.""" +        if isinstance(markup, str): +            # Let everything through the text filter if we're including strings and tags. +            if not self.name and not self.attrs and self.include_strings: +                return markup +        else: +            return super().search(markup) + + +def _find_elements_until_tag( +    start_element: PageElement, +    end_tag_filter: Union[Container[str], Callable[[Tag], bool]], +    *, +    func: Callable, +    include_strings: bool = False, +    limit: int = None, +) -> List[Union[Tag, NavigableString]]: +    """ +    Get all elements up to `limit` or until a tag matching `end_tag_filter` is found. + +    `end_tag_filter` can be either a container of string names to check against, +    or a filtering callable that's applied to tags. + +    When `include_strings` is True, `NavigableString`s from the document will be included in the result along `Tag`s. + +    `func` takes in a BeautifulSoup unbound method for finding multiple elements, such as `BeautifulSoup.find_all`. +    The method is then iterated over and all elements until the matching tag or the limit are added to the return list. +    """ +    use_container_filter = not callable(end_tag_filter) +    elements = [] + +    for element in func(start_element, name=Strainer(include_strings=include_strings), limit=limit): +        if isinstance(element, Tag): +            if use_container_filter: +                if element.name in end_tag_filter: +                    break +            elif end_tag_filter(element): +                break +        elements.append(element) + +    return elements + + +_find_next_children_until_tag = partial(_find_elements_until_tag, func=partial(BeautifulSoup.find_all, recursive=False)) +_find_recursive_children_until_tag = partial(_find_elements_until_tag, func=BeautifulSoup.find_all) +_find_next_siblings_until_tag = partial(_find_elements_until_tag, func=BeautifulSoup.find_next_siblings) +_find_previous_siblings_until_tag = partial(_find_elements_until_tag, func=BeautifulSoup.find_previous_siblings) + + +def _class_filter_factory(class_names: Iterable[str]) -> Callable[[Tag], bool]: +    """Create callable that returns True when the passed in tag's class is in `class_names` or when it's a table.""" +    def match_tag(tag: Tag) -> bool: +        for attr in class_names: +            if attr in tag.get("class", ()): +                return True +        return tag.name == "table" + +    return match_tag + + +def get_general_description(start_element: Tag) -> List[Union[Tag, NavigableString]]: +    """ +    Get page content to a table or a tag with its class in `SEARCH_END_TAG_ATTRS`. + +    A headerlink tag is attempted to be found to skip repeating the symbol information in the description. +    If it's found it's used as the tag to start the search from instead of the `start_element`. +    """ +    child_tags = _find_recursive_children_until_tag(start_element, _class_filter_factory(["section"]), limit=100) +    header = next(filter(_class_filter_factory(["headerlink"]), child_tags), None) +    start_tag = header.parent if header is not None else start_element +    return _find_next_siblings_until_tag(start_tag, _class_filter_factory(_SEARCH_END_TAG_ATTRS), include_strings=True) + + +def get_dd_description(symbol: PageElement) -> List[Union[Tag, NavigableString]]: +    """Get the contents of the next dd tag, up to a dt or a dl tag.""" +    description_tag = symbol.find_next("dd") +    return _find_next_children_until_tag(description_tag, ("dt", "dl"), include_strings=True) + + +def get_signatures(start_signature: PageElement) -> List[str]: +    """ +    Collect up to `_MAX_SIGNATURE_AMOUNT` signatures from dt tags around the `start_signature` dt tag. + +    First the signatures under the `start_signature` are included; +    if less than 2 are found, tags above the start signature are added to the result if any are present. +    """ +    signatures = [] +    for element in ( +            *reversed(_find_previous_siblings_until_tag(start_signature, ("dd",), limit=2)), +            start_signature, +            *_find_next_siblings_until_tag(start_signature, ("dd",), limit=2), +    )[-MAX_SIGNATURE_AMOUNT:]: +        signature = _UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text) + +        if signature: +            signatures.append(signature) + +    return signatures diff --git a/bot/exts/info/doc/_inventory_parser.py b/bot/exts/info/doc/_inventory_parser.py new file mode 100644 index 000000000..80d5841a0 --- /dev/null +++ b/bot/exts/info/doc/_inventory_parser.py @@ -0,0 +1,126 @@ +import logging +import re +import zlib +from collections import defaultdict +from typing import AsyncIterator, DefaultDict, List, Optional, Tuple + +import aiohttp + +import bot + +log = logging.getLogger(__name__) + +FAILED_REQUEST_ATTEMPTS = 3 +_V2_LINE_RE = re.compile(r'(?x)(.+?)\s+(\S*:\S*)\s+(-?\d+)\s+?(\S*)\s+(.*)') + +InventoryDict = DefaultDict[str, List[Tuple[str, str]]] + + +class ZlibStreamReader: +    """Class used for decoding zlib data of a stream line by line.""" + +    READ_CHUNK_SIZE = 16 * 1024 + +    def __init__(self, stream: aiohttp.StreamReader) -> None: +        self.stream = stream + +    async def _read_compressed_chunks(self) -> AsyncIterator[bytes]: +        """Read zlib data in `READ_CHUNK_SIZE` sized chunks and decompress.""" +        decompressor = zlib.decompressobj() +        async for chunk in self.stream.iter_chunked(self.READ_CHUNK_SIZE): +            yield decompressor.decompress(chunk) + +        yield decompressor.flush() + +    async def __aiter__(self) -> AsyncIterator[str]: +        """Yield lines of decompressed text.""" +        buf = b'' +        async for chunk in self._read_compressed_chunks(): +            buf += chunk +            pos = buf.find(b'\n') +            while pos != -1: +                yield buf[:pos].decode() +                buf = buf[pos + 1:] +                pos = buf.find(b'\n') + + +async def _load_v1(stream: aiohttp.StreamReader) -> InventoryDict: +    invdata = defaultdict(list) + +    async for line in stream: +        name, type_, location = line.decode().rstrip().split(maxsplit=2) +        # version 1 did not add anchors to the location +        if type_ == "mod": +            type_ = "py:module" +            location += "#module-" + name +        else: +            type_ = "py:" + type_ +            location += "#" + name +        invdata[type_].append((name, location)) +    return invdata + + +async def _load_v2(stream: aiohttp.StreamReader) -> InventoryDict: +    invdata = defaultdict(list) + +    async for line in ZlibStreamReader(stream): +        m = _V2_LINE_RE.match(line.rstrip()) +        name, type_, _prio, location, _dispname = m.groups()  # ignore the parsed items we don't need +        if location.endswith("$"): +            location = location[:-1] + name + +        invdata[type_].append((name, location)) +    return invdata + + +async def _fetch_inventory(url: str) -> InventoryDict: +    """Fetch, parse and return an intersphinx inventory file from an url.""" +    timeout = aiohttp.ClientTimeout(sock_connect=5, sock_read=5) +    async with bot.instance.http_session.get(url, timeout=timeout, raise_for_status=True) as response: +        stream = response.content + +        inventory_header = (await stream.readline()).decode().rstrip() +        inventory_version = int(inventory_header[-1:]) +        await stream.readline()  # skip project name +        await stream.readline()  # skip project version + +        if inventory_version == 1: +            return await _load_v1(stream) + +        elif inventory_version == 2: +            if b"zlib" not in await stream.readline(): +                raise ValueError(f"Invalid inventory file at url {url}.") +            return await _load_v2(stream) + +        raise ValueError(f"Invalid inventory file at url {url}.") + + +async def fetch_inventory(url: str) -> Optional[InventoryDict]: +    """ +    Get an inventory dict from `url`, retrying `FAILED_REQUEST_ATTEMPTS` times on errors. + +    `url` should point at a valid sphinx objects.inv inventory file, which will be parsed into the +    inventory dict in the format of {"domain:role": [("symbol_name", "relative_url_to_symbol"), ...], ...} +    """ +    for attempt in range(1, FAILED_REQUEST_ATTEMPTS+1): +        try: +            inventory = await _fetch_inventory(url) +        except aiohttp.ClientConnectorError: +            log.warning( +                f"Failed to connect to inventory url at {url}; " +                f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})." +            ) +        except aiohttp.ClientError: +            log.error( +                f"Failed to get inventory from {url}; " +                f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})." +            ) +        except Exception: +            log.exception( +                f"An unexpected error has occurred during fetching of {url}; " +                f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})." +            ) +        else: +            return inventory + +    return None diff --git a/bot/exts/info/doc/_markdown.py b/bot/exts/info/doc/_markdown.py new file mode 100644 index 000000000..1b7d8232b --- /dev/null +++ b/bot/exts/info/doc/_markdown.py @@ -0,0 +1,58 @@ +from urllib.parse import urljoin + +from bs4.element import PageElement +from markdownify import MarkdownConverter + + +class DocMarkdownConverter(MarkdownConverter): +    """Subclass markdownify's MarkdownCoverter to provide custom conversion methods.""" + +    def __init__(self, *, page_url: str, **options): +        super().__init__(**options) +        self.page_url = page_url + +    def convert_li(self, el: PageElement, text: str, convert_as_inline: bool) -> str: +        """Fix markdownify's erroneous indexing in ol tags.""" +        parent = el.parent +        if parent is not None and parent.name == "ol": +            li_tags = parent.find_all("li") +            bullet = f"{li_tags.index(el)+1}." +        else: +            depth = -1 +            while el: +                if el.name == "ul": +                    depth += 1 +                el = el.parent +            bullets = self.options["bullets"] +            bullet = bullets[depth % len(bullets)] +        return f"{bullet} {text}\n" + +    def convert_hn(self, _n: int, el: PageElement, text: str, convert_as_inline: bool) -> str: +        """Convert h tags to bold text with ** instead of adding #.""" +        if convert_as_inline: +            return text +        return f"**{text}**\n\n" + +    def convert_code(self, el: PageElement, text: str, convert_as_inline: bool) -> str: +        """Undo `markdownify`s underscore escaping.""" +        return f"`{text}`".replace("\\", "") + +    def convert_pre(self, el: PageElement, text: str, convert_as_inline: bool) -> str: +        """Wrap any codeblocks in `py` for syntax highlighting.""" +        code = "".join(el.strings) +        return f"```py\n{code}```" + +    def convert_a(self, el: PageElement, text: str, convert_as_inline: bool) -> str: +        """Resolve relative URLs to `self.page_url`.""" +        el["href"] = urljoin(self.page_url, el["href"]) +        return super().convert_a(el, text, convert_as_inline) + +    def convert_p(self, el: PageElement, text: str, convert_as_inline: bool) -> str: +        """Include only one newline instead of two when the parent is a li tag.""" +        if convert_as_inline: +            return text + +        parent = el.parent +        if parent is not None and parent.name == "li": +            return f"{text}\n" +        return super().convert_p(el, text, convert_as_inline) diff --git a/bot/exts/info/doc/_parsing.py b/bot/exts/info/doc/_parsing.py new file mode 100644 index 000000000..bf840b96f --- /dev/null +++ b/bot/exts/info/doc/_parsing.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import logging +import re +import string +import textwrap +from collections import namedtuple +from typing import Collection, Iterable, Iterator, List, Optional, TYPE_CHECKING, Union + +from bs4 import BeautifulSoup +from bs4.element import NavigableString, Tag + +from bot.utils.helpers import find_nth_occurrence +from . import MAX_SIGNATURE_AMOUNT +from ._html import get_dd_description, get_general_description, get_signatures +from ._markdown import DocMarkdownConverter +if TYPE_CHECKING: +    from ._cog import DocItem + +log = logging.getLogger(__name__) + +_WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)") +_PARAMETERS_RE = re.compile(r"\((.+)\)") + +_NO_SIGNATURE_GROUPS = { +    "attribute", +    "envvar", +    "setting", +    "tempaltefilter", +    "templatetag", +    "term", +} +_EMBED_CODE_BLOCK_LINE_LENGTH = 61 +# _MAX_SIGNATURE_AMOUNT code block wrapped lines with py syntax highlight +_MAX_SIGNATURES_LENGTH = (_EMBED_CODE_BLOCK_LINE_LENGTH + 8) * MAX_SIGNATURE_AMOUNT +# Maximum embed description length - signatures on top +_MAX_DESCRIPTION_LENGTH = 2048 - _MAX_SIGNATURES_LENGTH +_TRUNCATE_STRIP_CHARACTERS = "!?:;." + string.whitespace + +BracketPair = namedtuple("BracketPair", ["opening_bracket", "closing_bracket"]) +_BRACKET_PAIRS = { +    "{": BracketPair("{", "}"), +    "(": BracketPair("(", ")"), +    "[": BracketPair("[", "]"), +    "<": BracketPair("<", ">"), +} + + +def _split_parameters(parameters_string: str) -> Iterator[str]: +    """ +    Split parameters of a signature into individual parameter strings on commas. + +    Long string literals are not accounted for. +    """ +    last_split = 0 +    depth = 0 +    current_search: Optional[BracketPair] = None + +    enumerated_string = enumerate(parameters_string) +    for index, character in enumerated_string: +        if character in {"'", '"'}: +            # Skip everything inside of strings, regardless of the depth. +            quote_character = character  # The closing quote must equal the opening quote. +            preceding_backslashes = 0 +            for _, character in enumerated_string: +                # If an odd number of backslashes precedes the quote, it was escaped. +                if character == quote_character and not preceding_backslashes % 2: +                    break +                if character == "\\": +                    preceding_backslashes += 1 +                else: +                    preceding_backslashes = 0 + +        elif current_search is None: +            if (current_search := _BRACKET_PAIRS.get(character)) is not None: +                depth = 1 +            elif character == ",": +                yield parameters_string[last_split:index] +                last_split = index + 1 + +        else: +            if character == current_search.opening_bracket: +                depth += 1 + +            elif character == current_search.closing_bracket: +                depth -= 1 +                if depth == 0: +                    current_search = None + +    yield parameters_string[last_split:] + + +def _truncate_signatures(signatures: Collection[str]) -> Union[List[str], Collection[str]]: +    """ +    Truncate passed signatures to not exceed `_MAX_SIGNATURES_LENGTH`. + +    If the signatures need to be truncated, parameters are collapsed until they fit withing the limit. +    Individual signatures can consist of max 1, 2, ..., `_MAX_SIGNATURE_AMOUNT` lines of text, +    inversely proportional to the amount of signatures. +    A maximum of `_MAX_SIGNATURE_AMOUNT` signatures is assumed to be passed. +    """ +    if sum(len(signature) for signature in signatures) <= _MAX_SIGNATURES_LENGTH: +        # Total length of signatures is under the length limit; no truncation needed. +        return signatures + +    max_signature_length = _EMBED_CODE_BLOCK_LINE_LENGTH * (MAX_SIGNATURE_AMOUNT + 1 - len(signatures)) +    formatted_signatures = [] +    for signature in signatures: +        signature = signature.strip() +        if len(signature) > max_signature_length: +            if (parameters_match := _PARAMETERS_RE.search(signature)) is None: +                # The signature has no parameters or the regex failed; perform a simple truncation of the text. +                formatted_signatures.append(textwrap.shorten(signature, max_signature_length, placeholder="...")) +                continue + +            truncated_signature = [] +            parameters_string = parameters_match[1] +            running_length = len(signature) - len(parameters_string) +            for parameter in _split_parameters(parameters_string): +                # Check if including this parameter would still be within the maximum length. +                if (len(parameter) + running_length) <= max_signature_length - 5:  # account for comma and placeholder +                    truncated_signature.append(parameter) +                    running_length += len(parameter) + 1 +                else: +                    # There's no more room for this parameter. Truncate the parameter list and put it in the signature. +                    truncated_signature.append(" ...") +                    formatted_signatures.append(signature.replace(parameters_string, ",".join(truncated_signature))) +                    break +        else: +            # The current signature is under the length limit; no truncation needed. +            formatted_signatures.append(signature) + +    return formatted_signatures + + +def _get_truncated_description( +    elements: Iterable[Union[Tag, NavigableString]], +    markdown_converter: DocMarkdownConverter, +    max_length: int, +    max_lines: int, +) -> str: +    """ +    Truncate the Markdown from `elements` to be at most `max_length` characters when rendered or `max_lines` newlines. + +    `max_length` limits the length of the rendered characters in the string, +    with the real string length limited to `_MAX_DESCRIPTION_LENGTH` to accommodate discord length limits. +    """ +    result = "" +    markdown_element_ends = []  # Stores indices into `result` which point to the end boundary of each Markdown element. +    rendered_length = 0 + +    tag_end_index = 0 +    for element in elements: +        is_tag = isinstance(element, Tag) +        element_length = len(element.text) if is_tag else len(element) + +        if rendered_length + element_length < max_length: +            if is_tag: +                element_markdown = markdown_converter.process_tag(element, convert_as_inline=False) +            else: +                element_markdown = markdown_converter.process_text(element) + +            rendered_length += element_length +            tag_end_index += len(element_markdown) + +            if not element_markdown.isspace(): +                markdown_element_ends.append(tag_end_index) +            result += element_markdown +        else: +            break + +    if not markdown_element_ends: +        return "" + +    # Determine the "hard" truncation index. Account for the ellipsis placeholder for the max length. +    newline_truncate_index = find_nth_occurrence(result, "\n", max_lines) +    if newline_truncate_index is not None and newline_truncate_index < _MAX_DESCRIPTION_LENGTH - 3: +        # Truncate based on maximum lines if there are more than the maximum number of lines. +        truncate_index = newline_truncate_index +    else: +        # There are less than the maximum number of lines; truncate based on the max char length. +        truncate_index = _MAX_DESCRIPTION_LENGTH - 3 + +    # Nothing needs to be truncated if the last element ends before the truncation index. +    if truncate_index >= markdown_element_ends[-1]: +        return result + +    # Determine the actual truncation index. +    possible_truncation_indices = [cut for cut in markdown_element_ends if cut < truncate_index] +    if not possible_truncation_indices: +        # In case there is no Markdown element ending before the truncation index, try to find a good cutoff point. +        force_truncated = result[:truncate_index] +        # If there is an incomplete codeblock, cut it out. +        if force_truncated.count("```") % 2: +            force_truncated = force_truncated[:force_truncated.rfind("```")] +        # Search for substrings to truncate at, with decreasing desirability. +        for string_ in ("\n\n", "\n", ". ", ", ", ",", " "): +            cutoff = force_truncated.rfind(string_) + +            if cutoff != -1: +                truncated_result = force_truncated[:cutoff] +                break +        else: +            truncated_result = force_truncated + +    else: +        # Truncate at the last Markdown element that comes before the truncation index. +        markdown_truncate_index = possible_truncation_indices[-1] +        truncated_result = result[:markdown_truncate_index] + +    return truncated_result.strip(_TRUNCATE_STRIP_CHARACTERS) + "..." + + +def _create_markdown(signatures: Optional[List[str]], description: Iterable[Tag], url: str) -> str: +    """ +    Create a Markdown string with the signatures at the top, and the converted html description below them. + +    The signatures are wrapped in python codeblocks, separated from the description by a newline. +    The result Markdown string is max 750 rendered characters for the description with signatures at the start. +    """ +    description = _get_truncated_description( +        description, +        markdown_converter=DocMarkdownConverter(bullets="•", page_url=url), +        max_length=750, +        max_lines=13 +    ) +    description = _WHITESPACE_AFTER_NEWLINES_RE.sub("", description) +    if signatures is not None: +        signature = "".join(f"```py\n{signature}```" for signature in _truncate_signatures(signatures)) +        return f"{signature}\n{description}" +    else: +        return description + + +def get_symbol_markdown(soup: BeautifulSoup, symbol_data: DocItem) -> Optional[str]: +    """ +    Return parsed Markdown of the passed item using the passed in soup, truncated to fit within a discord message. + +    The method of parsing and what information gets included depends on the symbol's group. +    """ +    symbol_heading = soup.find(id=symbol_data.symbol_id) +    if symbol_heading is None: +        return None +    signature = None +    # Modules, doc pages and labels don't point to description list tags but to tags like divs, +    # no special parsing can be done so we only try to include what's under them. +    if symbol_heading.name != "dt": +        description = get_general_description(symbol_heading) + +    elif symbol_data.group in _NO_SIGNATURE_GROUPS: +        description = get_dd_description(symbol_heading) + +    else: +        signature = get_signatures(symbol_heading) +        description = get_dd_description(symbol_heading) +    return _create_markdown(signature, description, symbol_data.url).replace("¶", "").strip() diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py new file mode 100644 index 000000000..ad764816f --- /dev/null +++ b/bot/exts/info/doc/_redis_cache.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import datetime +from typing import Optional, TYPE_CHECKING + +from async_rediscache.types.base import RedisObject, namespace_lock +if TYPE_CHECKING: +    from ._cog import DocItem + +WEEK_SECONDS = datetime.timedelta(weeks=1).total_seconds() + + +class DocRedisCache(RedisObject): +    """Interface for redis functionality needed by the Doc cog.""" + +    def __init__(self, *args, **kwargs): +        super().__init__(*args, **kwargs) +        self._set_expires = set() + +    @namespace_lock +    async def set(self, item: DocItem, value: str) -> None: +        """ +        Set the Markdown `value` for the symbol `item`. + +        All keys from a single page are stored together, expiring a week after the first set. +        """ +        url_key = remove_suffix(item.relative_url_path, ".html") +        redis_key = f"{self.namespace}:{item.package}:{url_key}" +        needs_expire = False + +        with await self._get_pool_connection() as connection: +            if redis_key not in self._set_expires: +                # An expire is only set if the key didn't exist before. +                # If this is the first time setting values for this key check if it exists and add it to +                # `_set_expires` to prevent redundant checks for subsequent uses with items from the same page. +                self._set_expires.add(redis_key) +                needs_expire = not await connection.exists(redis_key) + +            await connection.hset(redis_key, item.symbol_id, value) +            if needs_expire: +                await connection.expire(redis_key, WEEK_SECONDS) + +    @namespace_lock +    async def get(self, item: DocItem) -> Optional[str]: +        """Return the Markdown content of the symbol `item` if it exists.""" +        url_key = remove_suffix(item.relative_url_path, ".html") + +        with await self._get_pool_connection() as connection: +            return await connection.hget(f"{self.namespace}:{item.package}:{url_key}", item.symbol_id, encoding="utf8") + +    @namespace_lock +    async def delete(self, package: str) -> bool: +        """Remove all values for `package`; return True if at least one key was deleted, False otherwise.""" +        with await self._get_pool_connection() as connection: +            package_keys = [ +                package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*") +            ] +            if package_keys: +                await connection.delete(*package_keys) +                return True +            return False + + +def remove_suffix(string: str, suffix: str) -> str: +    """Remove `suffix` from end of `string`.""" +    # TODO replace usages with str.removesuffix on 3.9 +    if string.endswith(suffix): +        return string[:-len(suffix)] +    else: +        return string diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py index 5e2c4b417..834fee1b4 100644 --- a/bot/exts/info/information.py +++ b/bot/exts/info/information.py @@ -230,6 +230,11 @@ class Information(Cog):          if on_server and user.nick:              name = f"{user.nick} ({name})" +        if user.public_flags.verified_bot: +            name += f" {constants.Emojis.verified_bot}" +        elif user.bot: +            name += f" {constants.Emojis.bot}" +          badges = []          for badge, is_set in user.public_flags: diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py index 49e74f204..ef07c77a1 100644 --- a/bot/exts/info/source.py +++ b/bot/exts/info/source.py @@ -14,9 +14,10 @@ SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, co  class SourceConverter(commands.Converter):      """Convert an argument into a help command, tag, command, or cog.""" -    async def convert(self, ctx: commands.Context, argument: str) -> SourceType: +    @staticmethod +    async def convert(ctx: commands.Context, argument: str) -> SourceType:          """Convert argument into source object.""" -        if argument.lower().startswith("help"): +        if argument.lower() == "help":              return ctx.bot.help_command          cog = ctx.bot.get_cog(argument) @@ -68,7 +69,8 @@ class BotSource(commands.Cog):          Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval).          """          if isinstance(source_item, commands.Command): -            src = source_item.callback.__code__ +            source_item = inspect.unwrap(source_item.callback) +            src = source_item.__code__              filename = src.co_filename          elif isinstance(source_item, str):              tags_cog = self.bot.get_cog("Tags") diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py index d89e80acc..38d1ffc0e 100644 --- a/bot/exts/moderation/infraction/infractions.py +++ b/bot/exts/moderation/infraction/infractions.py @@ -54,8 +54,12 @@ class Infractions(InfractionScheduler, commands.Cog):      # region: Permanent infractions      @command() -    async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: +    async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:          """Warn a user for the given reason.""" +        if not isinstance(user, Member): +            await ctx.send(":x: The user doesn't appear to be on the server.") +            return +          infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False)          if infraction is None:              return @@ -63,8 +67,12 @@ class Infractions(InfractionScheduler, commands.Cog):          await self.apply_infraction(ctx, infraction, user)      @command() -    async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None: +    async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:          """Kick a user for the given reason.""" +        if not isinstance(user, Member): +            await ctx.send(":x: The user doesn't appear to be on the server.") +            return +          await self.apply_kick(ctx, user, reason)      @command() @@ -100,7 +108,7 @@ class Infractions(InfractionScheduler, commands.Cog):      @command(aliases=["mute"])      async def tempmute(          self, ctx: Context, -        user: Member, +        user: FetchedMember,          duration: t.Optional[Expiry] = None,          *,          reason: t.Optional[str] = None @@ -122,6 +130,10 @@ class Infractions(InfractionScheduler, commands.Cog):          If no duration is given, a one hour duration is used by default.          """ +        if not isinstance(user, Member): +            await ctx.send(":x: The user doesn't appear to be on the server.") +            return +          if duration is None:              duration = await Duration().convert(ctx, "1h")          await self.apply_mute(ctx, user, reason, expires_at=duration) diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 704dddf9c..07e79b9fe 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -11,7 +11,7 @@ from discord.utils import escape_markdown  from bot import constants  from bot.bot import Bot -from bot.converters import Expiry +from bot.converters import Duration, Expiry  from bot.exts.moderation.infraction import _utils  from bot.exts.moderation.infraction._scheduler import InfractionScheduler  from bot.utils.messages import format_user @@ -19,6 +19,7 @@ from bot.utils.time import format_infraction  log = logging.getLogger(__name__)  NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" +SUPERSTARIFY_DEFAULT_DURATION = "1h"  with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file:      STAR_NAMES = json.load(stars_file) @@ -109,7 +110,7 @@ class Superstarify(InfractionScheduler, Cog):          self,          ctx: Context,          member: Member, -        duration: Expiry, +        duration: t.Optional[Expiry],          *,          reason: str = '',      ) -> None: @@ -134,6 +135,9 @@ class Superstarify(InfractionScheduler, Cog):          if await _utils.get_active_infraction(ctx, member, "superstar"):              return +        # Set to default duration if none was provided. +        duration = duration or await Duration().convert(ctx, SUPERSTARIFY_DEFAULT_DURATION) +          # Post the infraction to the API          old_nick = member.display_name          infraction_reason = f'Old nickname: {old_nick}. {reason}' diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py index 2dae9d268..e92f76c9a 100644 --- a/bot/exts/moderation/modlog.py +++ b/bot/exts/moderation/modlog.py @@ -14,7 +14,7 @@ from discord.abc import GuildChannel  from discord.ext.commands import Cog, Context  from bot.bot import Bot -from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs +from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs  from bot.utils.messages import format_user  from bot.utils.time import humanize_delta @@ -115,9 +115,9 @@ class ModLog(Cog, name="ModLog"):          if ping_everyone:              if content: -                content = f"@everyone\n{content}" +                content = f"<@&{Roles.moderators}>\n{content}"              else: -                content = "@everyone" +                content = f"<@&{Roles.moderators}>"          # Truncate content to 2000 characters and append an ellipsis.          if content and len(content) > 2000: @@ -127,8 +127,7 @@ class ModLog(Cog, name="ModLog"):          log_message = await channel.send(              content=content,              embed=embed, -            files=files, -            allowed_mentions=discord.AllowedMentions(everyone=True) +            files=files          )          if additional_embeds: diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py new file mode 100644 index 000000000..1ad5005de --- /dev/null +++ b/bot/exts/moderation/modpings.py @@ -0,0 +1,138 @@ +import datetime +import logging + +from async_rediscache import RedisCache +from dateutil.parser import isoparse +from discord import Embed, Member +from discord.ext.commands import Cog, Context, group, has_any_role + +from bot.bot import Bot +from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles +from bot.converters import Expiry +from bot.utils.scheduling import Scheduler + +log = logging.getLogger(__name__) + + +class ModPings(Cog): +    """Commands for a moderator to turn moderator pings on and off.""" + +    # RedisCache[discord.Member.id, 'Naïve ISO 8601 string'] +    # The cache's keys are mods who have pings off. +    # The cache's values are the times when the role should be re-applied to them, stored in ISO format. +    pings_off_mods = RedisCache() + +    def __init__(self, bot: Bot): +        self.bot = bot +        self._role_scheduler = Scheduler(self.__class__.__name__) + +        self.guild = None +        self.moderators_role = None + +        self.reschedule_task = self.bot.loop.create_task(self.reschedule_roles(), name="mod-pings-reschedule") + +    async def reschedule_roles(self) -> None: +        """Reschedule moderators role re-apply times.""" +        await self.bot.wait_until_guild_available() +        self.guild = self.bot.get_guild(Guild.id) +        self.moderators_role = self.guild.get_role(Roles.moderators) + +        mod_team = self.guild.get_role(Roles.mod_team) +        pings_on = self.moderators_role.members +        pings_off = await self.pings_off_mods.to_dict() + +        log.trace("Applying the moderators role to the mod team where necessary.") +        for mod in mod_team.members: +            if mod in pings_on:  # Make sure that on-duty mods aren't in the cache. +                if mod in pings_off: +                    await self.pings_off_mods.delete(mod.id) +                continue + +            # Keep the role off only for those in the cache. +            if mod.id not in pings_off: +                await self.reapply_role(mod) +            else: +                expiry = isoparse(pings_off[mod.id]).replace(tzinfo=None) +                self._role_scheduler.schedule_at(expiry, mod.id, self.reapply_role(mod)) + +    async def reapply_role(self, mod: Member) -> None: +        """Reapply the moderator's role to the given moderator.""" +        log.trace(f"Re-applying role to mod with ID {mod.id}.") +        await mod.add_roles(self.moderators_role, reason="Pings off period expired.") + +    @group(name='modpings', aliases=('modping',), invoke_without_command=True) +    @has_any_role(*MODERATION_ROLES) +    async def modpings_group(self, ctx: Context) -> None: +        """Allow the removal and re-addition of the pingable moderators role.""" +        await ctx.send_help(ctx.command) + +    @modpings_group.command(name='off') +    @has_any_role(*MODERATION_ROLES) +    async def off_command(self, ctx: Context, duration: Expiry) -> None: +        """ +        Temporarily removes the pingable moderators role for a set amount of time. + +        A unit of time should be appended to the duration. +        Units (∗case-sensitive): +        \u2003`y` - years +        \u2003`m` - months∗ +        \u2003`w` - weeks +        \u2003`d` - days +        \u2003`h` - hours +        \u2003`M` - minutes∗ +        \u2003`s` - seconds + +        Alternatively, an ISO 8601 timestamp can be provided for the duration. + +        The duration cannot be longer than 30 days. +        """ +        duration: datetime.datetime +        delta = duration - datetime.datetime.utcnow() +        if delta > datetime.timedelta(days=30): +            await ctx.send(":x: Cannot remove the role for longer than 30 days.") +            return + +        mod = ctx.author + +        until_date = duration.replace(microsecond=0).isoformat()  # Looks noisy with microseconds. +        await mod.remove_roles(self.moderators_role, reason=f"Turned pings off until {until_date}.") + +        await self.pings_off_mods.set(mod.id, duration.isoformat()) + +        # Allow rescheduling the task without cancelling it separately via the `on` command. +        if mod.id in self._role_scheduler: +            self._role_scheduler.cancel(mod.id) +        self._role_scheduler.schedule_at(duration, mod.id, self.reapply_role(mod)) + +        embed = Embed(timestamp=duration, colour=Colours.bright_green) +        embed.set_footer(text="Moderators role has been removed until", icon_url=Icons.green_checkmark) +        await ctx.send(embed=embed) + +    @modpings_group.command(name='on') +    @has_any_role(*MODERATION_ROLES) +    async def on_command(self, ctx: Context) -> None: +        """Re-apply the pingable moderators role.""" +        mod = ctx.author +        if mod in self.moderators_role.members: +            await ctx.send(":question: You already have the role.") +            return + +        await mod.add_roles(self.moderators_role, reason="Pings off period canceled.") + +        await self.pings_off_mods.delete(mod.id) + +        # We assume the task exists. Lack of it may indicate a bug. +        self._role_scheduler.cancel(mod.id) + +        await ctx.send(f"{Emojis.check_mark} Moderators role has been re-applied.") + +    def cog_unload(self) -> None: +        """Cancel role tasks when the cog unloads.""" +        log.trace("Cog unload: canceling role tasks.") +        self.reschedule_task.cancel() +        self._role_scheduler.cancel_all() + + +def setup(bot: Bot) -> None: +    """Load the ModPings cog.""" +    bot.add_cog(ModPings(bot)) diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py index 12e195172..fd856a7f4 100644 --- a/bot/exts/moderation/stream.py +++ b/bot/exts/moderation/stream.py @@ -1,5 +1,6 @@  import logging  from datetime import timedelta, timezone +from operator import itemgetter  import arrow  import discord @@ -8,8 +9,9 @@ from async_rediscache import RedisCache  from discord.ext import commands  from bot.bot import Bot -from bot.constants import Colours, Emojis, Guild, Roles, STAFF_ROLES, VideoPermission +from bot.constants import Colours, Emojis, Guild, MODERATION_ROLES, Roles, STAFF_ROLES, VideoPermission  from bot.converters import Expiry +from bot.pagination import LinePaginator  from bot.utils.scheduling import Scheduler  from bot.utils.time import format_infraction_with_duration @@ -68,8 +70,30 @@ class Stream(commands.Cog):                  self._revoke_streaming_permission(member)              ) +    async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None: +        """Suspend a member's stream.""" +        await self.bot.wait_until_guild_available() +        voice_state = member.voice + +        if not voice_state: +            return + +        # If the user is streaming. +        if voice_state.self_stream: +            # End user's stream by moving them to AFK voice channel and back. +            original_vc = voice_state.channel +            await member.move_to(ctx.guild.afk_channel) +            await member.move_to(original_vc) + +            # Notify. +            await ctx.send(f"{member.mention}'s stream has been suspended!") +            log.debug(f"Successfully suspended stream from {member} ({member.id}).") +            return + +        log.debug(f"No stream found to suspend from {member} ({member.id}).") +      @commands.command(aliases=("streaming",)) -    @commands.has_any_role(*STAFF_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None:          """          Temporarily grant streaming permissions to a member for a given duration. @@ -126,7 +150,7 @@ class Stream(commands.Cog):          log.debug(f"Successfully gave {member} ({member.id}) permission to stream until {revoke_time}.")      @commands.command(aliases=("pstream",)) -    @commands.has_any_role(*STAFF_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def permanentstream(self, ctx: commands.Context, member: discord.Member) -> None:          """Permanently grants the given member the permission to stream."""          log.trace(f"Attempting to give permanent streaming permission to {member} ({member.id}).") @@ -153,7 +177,7 @@ class Stream(commands.Cog):          log.debug(f"Successfully gave {member} ({member.id}) permanent streaming permission.")      @commands.command(aliases=("unstream", "rstream")) -    @commands.has_any_role(*STAFF_ROLES) +    @commands.has_any_role(*MODERATION_ROLES)      async def revokestream(self, ctx: commands.Context, member: discord.Member) -> None:          """Revoke the permission to stream from the given member."""          log.trace(f"Attempting to remove streaming permission from {member} ({member.id}).") @@ -168,10 +192,52 @@ class Stream(commands.Cog):              await ctx.send(f"{Emojis.check_mark} Revoked the permission to stream from {member.mention}.")              log.debug(f"Successfully revoked streaming permission from {member} ({member.id}).") -            return -        await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!") -        log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!") +        else: +            await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!") +            log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!") + +        await self._suspend_stream(ctx, member) + +    @commands.command(aliases=('lstream',)) +    @commands.has_any_role(*MODERATION_ROLES) +    async def liststream(self, ctx: commands.Context) -> None: +        """Lists all non-staff users who have permission to stream.""" +        non_staff_members_with_stream = [ +            member +            for member in ctx.guild.get_role(Roles.video).members +            if not any(role.id in STAFF_ROLES for role in member.roles) +        ] + +        # List of tuples (UtcPosixTimestamp, str) +        # So that the list can be sorted on the UtcPosixTimestamp before the message is passed to the paginator. +        streamer_info = [] +        for member in non_staff_members_with_stream: +            if revoke_time := await self.task_cache.get(member.id): +                # Member only has temporary streaming perms +                revoke_delta = Arrow.utcfromtimestamp(revoke_time).humanize() +                message = f"{member.mention} will have stream permissions revoked {revoke_delta}." +            else: +                message = f"{member.mention} has permanent streaming permissions." + +            # If revoke_time is None use max timestamp to force sort to put them at the end +            streamer_info.append( +                (revoke_time or Arrow.max.timestamp(), message) +            ) + +        if streamer_info: +            # Sort based on duration left of streaming perms +            streamer_info.sort(key=itemgetter(0)) + +            # Only output the message in the pagination +            lines = [line[1] for line in streamer_info] +            embed = discord.Embed( +                title=f"Members with streaming permission (`{len(lines)}` total)", +                colour=Colours.soft_green +            ) +            await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False) +        else: +            await ctx.send("No members with stream permissions found.")  def setup(bot: Bot) -> None: diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py index 8acaf9131..cb662e852 100644 --- a/bot/exts/utils/clean.py +++ b/bot/exts/utils/clean.py @@ -3,7 +3,7 @@ import random  import re  from typing import Iterable, Optional -from discord import Colour, Embed, Message, TextChannel, User +from discord import Colour, Embed, Message, TextChannel, User, errors  from discord.ext import commands  from discord.ext.commands import Cog, Context, group, has_any_role @@ -115,7 +115,11 @@ class Clean(Cog):          # Delete the invocation first          self.mod_log.ignore(Event.message_delete, ctx.message.id) -        await ctx.message.delete() +        try: +            await ctx.message.delete() +        except errors.NotFound: +            # Invocation message has already been deleted +            log.info("Tried to delete invocation message, but it was already deleted.")          messages = []          message_ids = [] diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py index 3113a1149..6c21920a1 100644 --- a/bot/exts/utils/reminders.py +++ b/bot/exts/utils/reminders.py @@ -90,15 +90,18 @@ class Reminders(Cog):          delivery_dt: t.Optional[datetime],      ) -> None:          """Send an embed confirming the reminder change was made successfully.""" -        embed = discord.Embed() -        embed.colour = discord.Colour.green() -        embed.title = random.choice(POSITIVE_REPLIES) -        embed.description = on_success +        embed = discord.Embed( +            description=on_success, +            colour=discord.Colour.green(), +            title=random.choice(POSITIVE_REPLIES) +        )          footer_str = f"ID: {reminder_id}" +          if delivery_dt:              # Reminder deletion will have a `None` `delivery_dt` -            footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" +            footer_str += ', Due' +            embed.timestamp = delivery_dt          embed.set_footer(text=footer_str) diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py index 9f480c067..da95240bb 100644 --- a/bot/exts/utils/snekbox.py +++ b/bot/exts/utils/snekbox.py @@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, command, guild_only  from bot.bot import Bot  from bot.constants import Categories, Channels, Roles, URLs -from bot.decorators import in_whitelist +from bot.decorators import not_in_blacklist  from bot.utils import send_to_paste_service  from bot.utils.messages import wait_for_deletion @@ -38,9 +38,9 @@ RAW_CODE_REGEX = re.compile(  MAX_PASTE_LEN = 10000 -# `!eval` command whitelists -EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric) -EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use, Categories.voice) +# `!eval` command whitelists and blacklists. +NO_EVAL_CHANNELS = (Channels.python_general,) +NO_EVAL_CATEGORIES = ()  EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)  SIGKILL = 9 @@ -280,7 +280,7 @@ class Snekbox(Cog):      @command(name="eval", aliases=("e",))      @guild_only() -    @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES) +    @not_in_blacklist(channels=NO_EVAL_CHANNELS, categories=NO_EVAL_CATEGORIES, override_roles=EVAL_ROLES)      async def eval_command(self, ctx: Context, *, code: str = None) -> None:          """          Run Python code and get the results. diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py index cae7f2593..4c39a7c2a 100644 --- a/bot/exts/utils/utils.py +++ b/bot/exts/utils/utils.py @@ -109,7 +109,7 @@ class Utils(Cog):          # handle if it's an index int          if isinstance(search_value, int):              upper_bound = len(zen_lines) - 1 -            lower_bound = -1 * upper_bound +            lower_bound = -1 * len(zen_lines)              if not (lower_bound <= search_value <= upper_bound):                  raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.") @@ -162,17 +162,27 @@ class Utils(Cog):          if len(snowflakes) > 1 and await has_no_roles_check(ctx, *STAFF_ROLES):              raise BadArgument("Cannot process more than one snowflake in one invocation.") +        if not snowflakes: +            raise BadArgument("At least one snowflake must be provided.") + +        embed = Embed(colour=Colour.blue()) +        embed.set_author( +            name=f"Snowflake{'s'[:len(snowflakes)^1]}",  # Deals with pluralisation +            icon_url="https://github.com/twitter/twemoji/blob/master/assets/72x72/2744.png?raw=true" +        ) + +        lines = []          for snowflake in snowflakes:              created_at = snowflake_time(snowflake) -            embed = Embed( -                description=f"**Created at {created_at}** ({time_since(created_at, max_units=3)}).", -                colour=Colour.blue() -            ) -            embed.set_author( -                name=f"Snowflake: {snowflake}", -                icon_url="https://github.com/twitter/twemoji/blob/master/assets/72x72/2744.png?raw=true" -            ) -            await ctx.send(embed=embed) +            lines.append(f"**{snowflake}**\nCreated at {created_at} ({time_since(created_at, max_units=3)}).") + +        await LinePaginator.paginate( +            lines, +            ctx=ctx, +            embed=embed, +            max_lines=5, +            max_size=1000 +        )      @command(aliases=("poll",))      @has_any_role(*MODERATION_ROLES, Roles.project_leads, Roles.domain_leads) diff --git a/bot/log.py b/bot/log.py index e92233a33..4e20c005e 100644 --- a/bot/log.py +++ b/bot/log.py @@ -20,7 +20,6 @@ def setup() -> None:      logging.addLevelName(TRACE_LEVEL, "TRACE")      Logger.trace = _monkeypatch_trace -    log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO      format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"      log_format = logging.Formatter(format_string) @@ -30,7 +29,6 @@ def setup() -> None:      file_handler.setFormatter(log_format)      root_log = logging.getLogger() -    root_log.setLevel(log_level)      root_log.addHandler(file_handler)      if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: @@ -44,11 +42,9 @@ def setup() -> None:      if "COLOREDLOGS_LOG_FORMAT" not in os.environ:          coloredlogs.DEFAULT_LOG_FORMAT = format_string -    if "COLOREDLOGS_LOG_LEVEL" not in os.environ: -        coloredlogs.DEFAULT_LOG_LEVEL = log_level - -    coloredlogs.install(logger=root_log, stream=sys.stdout) +    coloredlogs.install(level=logging.TRACE, logger=root_log, stream=sys.stdout) +    root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO)      logging.getLogger("discord").setLevel(logging.WARNING)      logging.getLogger("websockets").setLevel(logging.WARNING)      logging.getLogger("chardet").setLevel(logging.WARNING) @@ -57,6 +53,8 @@ def setup() -> None:      # Set back to the default of INFO even if asyncio's debug mode is enabled.      logging.getLogger("asyncio").setLevel(logging.INFO) +    _set_trace_loggers() +  def setup_sentry() -> None:      """Set up the Sentry logging integrations.""" @@ -86,3 +84,30 @@ def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:      """      if self.isEnabledFor(TRACE_LEVEL):          self._log(TRACE_LEVEL, msg, args, **kwargs) + + +def _set_trace_loggers() -> None: +    """ +    Set loggers to the trace level according to the value from the BOT_TRACE_LOGGERS env var. + +    When the env var is a list of logger names delimited by a comma, +    each of the listed loggers will be set to the trace level. + +    If this list is prefixed with a "!", all of the loggers except the listed ones will be set to the trace level. + +    Otherwise if the env var begins with a "*", +    the root logger is set to the trace level and other contents are ignored. +    """ +    level_filter = constants.Bot.trace_loggers +    if level_filter: +        if level_filter.startswith("*"): +            logging.getLogger().setLevel(logging.TRACE) + +        elif level_filter.startswith("!"): +            logging.getLogger().setLevel(logging.TRACE) +            for logger_name in level_filter.strip("!,").split(","): +                logging.getLogger(logger_name).setLevel(logging.DEBUG) + +        else: +            for logger_name in level_filter.strip(",").split(","): +                logging.getLogger(logger_name).setLevel(logging.TRACE) diff --git a/bot/pagination.py b/bot/pagination.py index 3b16cc9ff..c5c84afd9 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -2,14 +2,14 @@ import asyncio  import logging  import typing as t  from contextlib import suppress +from functools import partial  import discord -from discord import Member  from discord.abc import User  from discord.ext.commands import Context, Paginator  from bot import constants -from bot.constants import MODERATION_ROLES +from bot.utils import messages  FIRST_EMOJI = "\u23EE"   # [:track_previous:]  LEFT_EMOJI = "\u2B05"    # [:arrow_left:] @@ -220,29 +220,6 @@ class LinePaginator(Paginator):          >>> embed.set_author(name="Some Operation", url=url, icon_url=icon)          >>> await LinePaginator.paginate([line for line in lines], ctx, embed)          """ -        def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool: -            """Make sure that this reaction is what we want to operate on.""" -            no_restrictions = ( -                # The reaction was by a whitelisted user -                user_.id == restrict_to_user.id -                # The reaction was by a moderator -                or isinstance(user_, Member) and any(role.id in MODERATION_ROLES for role in user_.roles) -            ) - -            return ( -                # Conditions for a successful pagination: -                all(( -                    # Reaction is on this message -                    reaction_.message.id == message.id, -                    # Reaction is one of the pagination emotes -                    str(reaction_.emoji) in PAGINATION_EMOJI, -                    # Reaction was not made by the Bot -                    user_.id != ctx.bot.user.id, -                    # There were no restrictions -                    no_restrictions -                )) -            ) -          paginator = cls(prefix=prefix, suffix=suffix, max_size=max_size, max_lines=max_lines,                          scale_to_size=scale_to_size)          current_page = 0 @@ -303,9 +280,16 @@ class LinePaginator(Paginator):              log.trace(f"Adding reaction: {repr(emoji)}")              await message.add_reaction(emoji) +        check = partial( +            messages.reaction_check, +            message_id=message.id, +            allowed_emoji=PAGINATION_EMOJI, +            allowed_users=(restrict_to_user.id,), +        ) +          while True:              try: -                reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=event_check) +                reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)                  log.trace(f"Got reaction: {reaction}")              except asyncio.TimeoutError:                  log.debug("Timed out waiting for a reaction") diff --git a/bot/resources/tags/customchecks.md b/bot/resources/tags/customchecks.md new file mode 100644 index 000000000..23ff7a66f --- /dev/null +++ b/bot/resources/tags/customchecks.md @@ -0,0 +1,21 @@ +**Custom Command Checks in discord.py** + +Often you may find the need to use checks that don't exist by default in discord.py. Fortunately, discord.py provides `discord.ext.commands.check` which allows you to create you own checks like this: +```py +from discord.ext.commands import check, Context + +def in_any_channel(*channels): +  async def predicate(ctx: Context): +    return ctx.channel.id in channels +  return check(predicate) +``` +This check is to check whether the invoked command is in a given set of channels. The inner function, named `predicate` here, is used to perform the actual check on the command, and check logic should go in this function. It must be an async function, and always provides a single `commands.Context` argument which you can use to create check logic. This check function should return a boolean value indicating whether the check passed (return `True`) or failed (return `False`). + +The check can now be used like any other commands check as a decorator of a command, such as this: +```py [email protected](name="ping") +@in_any_channel(728343273562701984) +async def ping(ctx: Context): +  ... +``` +This would lock the `ping` command to only be used in the channel `728343273562701984`. If this check function fails it will raise a `CheckFailure` exception, which can be handled in your error handler. diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 460a937d8..3d0c8a50c 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -20,8 +20,8 @@ from bot import constants  log = logging.getLogger(__name__) -class InWhitelistCheckFailure(CheckFailure): -    """Raised when the `in_whitelist` check fails.""" +class ContextCheckFailure(CheckFailure): +    """Raised when a context-specific check fails."""      def __init__(self, redirect_channel: Optional[int]) -> None:          self.redirect_channel = redirect_channel @@ -36,6 +36,10 @@ class InWhitelistCheckFailure(CheckFailure):          super().__init__(error_message) +class InWhitelistCheckFailure(ContextCheckFailure): +    """Raised when the `in_whitelist` check fails.""" + +  def in_whitelist_check(      ctx: Context,      channels: Container[int] = (), diff --git a/bot/utils/function.py b/bot/utils/function.py index 3ab32fe3c..9bc44e753 100644 --- a/bot/utils/function.py +++ b/bot/utils/function.py @@ -1,14 +1,23 @@  """Utilities for interaction with functions.""" +import functools  import inspect +import logging +import types  import typing as t +log = logging.getLogger(__name__) +  Argument = t.Union[int, str]  BoundArgs = t.OrderedDict[str, t.Any]  Decorator = t.Callable[[t.Callable], t.Callable]  ArgValGetter = t.Callable[[BoundArgs], t.Any] +class GlobalNameConflictError(Exception): +    """Raised when there's a conflict between the globals used to resolve annotations of wrapped and its wrapper.""" + +  def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any:      """      Return a value from `arguments` based on a name or position. @@ -73,3 +82,66 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any])      bound_args.apply_defaults()      return bound_args.arguments + + +def update_wrapper_globals( +        wrapper: types.FunctionType, +        wrapped: types.FunctionType, +        *, +        ignored_conflict_names: t.Set[str] = frozenset(), +) -> types.FunctionType: +    """ +    Update globals of `wrapper` with the globals from `wrapped`. + +    For forwardrefs in command annotations discordpy uses the __global__ attribute of the function +    to resolve their values, with decorators that replace the function this breaks because they have +    their own globals. + +    This function creates a new function functionally identical to `wrapper`, which has the globals replaced with +    a merge of `wrapped`s globals and the `wrapper`s globals. + +    An exception will be raised in case `wrapper` and `wrapped` share a global name that is used by +    `wrapped`'s typehints and is not in `ignored_conflict_names`, +    as this can cause incorrect objects being used by discordpy's converters. +    """ +    annotation_global_names = ( +        ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) +    ) +    # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. +    shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names) +    shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__) - ignored_conflict_names +    if shared_globals: +        raise GlobalNameConflictError( +            f"wrapper and the wrapped function share the following " +            f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add " +            f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional." +        ) + +    new_globals = wrapper.__globals__.copy() +    new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) +    return types.FunctionType( +        code=wrapper.__code__, +        globals=new_globals, +        name=wrapper.__name__, +        argdefs=wrapper.__defaults__, +        closure=wrapper.__closure__, +    ) + + +def command_wraps( +        wrapped: types.FunctionType, +        assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS, +        updated: t.Sequence[str] = functools.WRAPPER_UPDATES, +        *, +        ignored_conflict_names: t.Set[str] = frozenset(), +) -> t.Callable[[types.FunctionType], types.FunctionType]: +    """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation.""" +    def decorator(wrapper: types.FunctionType) -> types.FunctionType: +        return functools.update_wrapper( +            update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names), +            wrapped, +            assigned, +            updated, +        ) + +    return decorator diff --git a/bot/utils/lock.py b/bot/utils/lock.py index e44776340..ec6f92cd4 100644 --- a/bot/utils/lock.py +++ b/bot/utils/lock.py @@ -1,13 +1,15 @@  import asyncio  import inspect  import logging +import types  from collections import defaultdict -from functools import partial, wraps +from functools import partial  from typing import Any, Awaitable, Callable, Hashable, Union  from weakref import WeakValueDictionary  from bot.errors import LockedResourceError  from bot.utils import function +from bot.utils.function import command_wraps  log = logging.getLogger(__name__)  __lock_dicts = defaultdict(WeakValueDictionary) @@ -17,6 +19,35 @@ _IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]  ResourceId = Union[Hashable, _IdCallable] +class SharedEvent: +    """ +    Context manager managing an internal event exposed through the wait coro. + +    While any code is executing in this context manager, the underlying event will not be set; +    when all of the holders finish the event will be set. +    """ + +    def __init__(self): +        self._active_count = 0 +        self._event = asyncio.Event() +        self._event.set() + +    def __enter__(self): +        """Increment the count of the active holders and clear the internal event.""" +        self._active_count += 1 +        self._event.clear() + +    def __exit__(self, _exc_type, _exc_val, _exc_tb):  # noqa: ANN001 +        """Decrement the count of the active holders; if 0 is reached set the internal event.""" +        self._active_count -= 1 +        if not self._active_count: +            self._event.set() + +    async def wait(self) -> None: +        """Wait for all active holders to exit.""" +        await self._event.wait() + +  def lock(      namespace: Hashable,      resource_id: ResourceId, @@ -41,10 +72,10 @@ def lock(      If decorating a command, this decorator must go before (below) the `command` decorator.      """ -    def decorator(func: Callable) -> Callable: +    def decorator(func: types.FunctionType) -> types.FunctionType:          name = func.__name__ -        @wraps(func) +        @command_wraps(func)          async def wrapper(*args, **kwargs) -> Any:              log.trace(f"{name}: mutually exclusive decorator called") diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 077dd9569..2beead6af 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -3,6 +3,7 @@ import contextlib  import logging  import random  import re +from functools import partial  from io import BytesIO  from typing import List, Optional, Sequence, Union @@ -12,24 +13,66 @@ from discord.ext.commands import Context  import bot  from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES +from bot.utils import scheduling  log = logging.getLogger(__name__) +def reaction_check( +    reaction: discord.Reaction, +    user: discord.abc.User, +    *, +    message_id: int, +    allowed_emoji: Sequence[str], +    allowed_users: Sequence[int], +    allow_mods: bool = True, +) -> bool: +    """ +    Check if a reaction's emoji and author are allowed and the message is `message_id`. + +    If the user is not allowed, remove the reaction. Ignore reactions made by the bot. +    If `allow_mods` is True, allow users with moderator roles even if they're not in `allowed_users`. +    """ +    right_reaction = ( +        user != bot.instance.user +        and reaction.message.id == message_id +        and str(reaction.emoji) in allowed_emoji +    ) +    if not right_reaction: +        return False + +    is_moderator = ( +        allow_mods +        and any(role.id in MODERATION_ROLES for role in getattr(user, "roles", [])) +    ) + +    if user.id in allowed_users or is_moderator: +        log.trace(f"Allowed reaction {reaction} by {user} on {reaction.message.id}.") +        return True +    else: +        log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.") +        scheduling.create_task( +            reaction.message.remove_reaction(reaction.emoji, user), +            HTTPException,  # Suppress the HTTPException if adding the reaction fails +            name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}" +        ) +        return False + +  async def wait_for_deletion(      message: discord.Message, -    user_ids: Sequence[discord.abc.Snowflake], +    user_ids: Sequence[int],      deletion_emojis: Sequence[str] = (Emojis.trashcan,),      timeout: float = 60 * 5,      attach_emojis: bool = True, -    allow_moderation_roles: bool = True +    allow_mods: bool = True  ) -> None:      """      Wait for up to `timeout` seconds for a reaction by any of the specified `user_ids` to delete the message.      An `attach_emojis` bool may be specified to determine whether to attach the given      `deletion_emojis` to the message in the given `context`. -    An `allow_moderation_roles` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete +    An `allow_mods` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete      the message.      """      if message.guild is None: @@ -43,16 +86,13 @@ async def wait_for_deletion(                  log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.")                  return -    def check(reaction: discord.Reaction, user: discord.Member) -> bool: -        """Check that the deletion emoji is reacted by the appropriate user.""" -        return ( -            reaction.message.id == message.id -            and str(reaction.emoji) in deletion_emojis -            and ( -                user.id in user_ids -                or allow_moderation_roles and any(role.id in MODERATION_ROLES for role in user.roles) -            ) -        ) +    check = partial( +        reaction_check, +        message_id=message.id, +        allowed_emoji=deletion_emojis, +        allowed_users=user_ids, +        allow_mods=allow_mods, +    )      with contextlib.suppress(asyncio.TimeoutError):          await bot.instance.wait_for('reaction_add', check=check, timeout=timeout) @@ -141,14 +181,14 @@ def sub_clyde(username: Optional[str]) -> Optional[str]:          return username  # Empty string or None -async def send_denial(ctx: Context, reason: str) -> None: +async def send_denial(ctx: Context, reason: str) -> discord.Message:      """Send an embed denying the user with the given reason."""      embed = discord.Embed()      embed.colour = discord.Colour.red()      embed.title = random.choice(NEGATIVE_REPLIES)      embed.description = reason -    await ctx.send(embed=embed) +    return await ctx.send(embed=embed)  def format_user(user: discord.abc.User) -> str: diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py index 6843bae88..2dc485f24 100644 --- a/bot/utils/scheduling.py +++ b/bot/utils/scheduling.py @@ -161,18 +161,18 @@ class Scheduler:                  self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception) -def create_task(*args, **kwargs) -> asyncio.Task: +def create_task(coro: t.Awaitable, *suppressed_exceptions: t.Type[Exception], **kwargs) -> asyncio.Task:      """Wrapper for `asyncio.create_task` which logs exceptions raised in the task.""" -    task = asyncio.create_task(*args, **kwargs) -    task.add_done_callback(_log_task_exception) +    task = asyncio.create_task(coro, **kwargs) +    task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))      return task -def _log_task_exception(task: asyncio.Task) -> None: +def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None:      """Retrieve and log the exception raised in `task` if one exists."""      with contextlib.suppress(asyncio.CancelledError):          exception = task.exception()          # Log the exception if one exists. -        if exception: +        if exception and not isinstance(exception, suppressed_exceptions):              log = logging.getLogger(__name__)              log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception) diff --git a/config-default.yml b/config-default.yml index 8c6e18470..46475f845 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,7 +1,8 @@  bot: -    prefix:      "!" -    sentry_dsn:  !ENV "BOT_SENTRY_DSN" -    token:       !ENV "BOT_TOKEN" +    prefix:         "!" +    sentry_dsn:     !ENV "BOT_SENTRY_DSN" +    token:          !ENV "BOT_TOKEN" +    trace_loggers:  !ENV "BOT_TRACE_LOGGERS"      clean:          # Maximum number of messages to traverse for clean commands @@ -46,6 +47,8 @@ style:          badge_partner: "<:partner:748666453242413136>"          badge_staff: "<:discord_staff:743882896498098226>"          badge_verified_bot_developer: "<:verified_bot_dev:743882897299210310>" +        bot: "<:bot:812712599464443914>" +        verified_bot: "<:verified_bot:811645219220750347>"          defcon_shutdown:    "<:defcondisabled:470326273952972810>"          defcon_unshutdown:  "<:defconenabled:470326274213150730>" @@ -260,7 +263,8 @@ guild:          devops:                             409416496733880320          domain_leads:                       807415650778742785          helpers:            &HELPERS_ROLE   267630620367257601 -        moderators:         &MODS_ROLE      267629731250176001 +        moderators:         &MODS_ROLE      831776746206265384 +        mod_team:           &MOD_TEAM_ROLE  267629731250176001          owners:             &OWNERS_ROLE    267627879762755584          project_leads:                      815701647526330398 @@ -273,13 +277,14 @@ guild:      moderation_roles:          - *ADMINS_ROLE +        - *MOD_TEAM_ROLE          - *MODS_ROLE          - *OWNERS_ROLE      staff_roles:          - *ADMINS_ROLE          - *HELPERS_ROLE -        - *MODS_ROLE +        - *MOD_TEAM_ROLE          - *OWNERS_ROLE      webhooks: diff --git a/tests/README.md b/tests/README.md index 4f62edd68..092324123 100644 --- a/tests/README.md +++ b/tests/README.md @@ -114,7 +114,7 @@ class BotCogTests(unittest.TestCase):  ### Mocking coroutines -By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8. +By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected.  ### Special mocks for some `discord.py` types diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py new file mode 100644 index 000000000..bd4fb5942 --- /dev/null +++ b/tests/bot/exts/backend/test_error_handler.py @@ -0,0 +1,550 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, call, patch + +from discord.ext.commands import errors + +from bot.api import ResponseCodeError +from bot.errors import InvalidInfractedUser, LockedResourceError +from bot.exts.backend.error_handler import ErrorHandler, setup +from bot.exts.info.tags import Tags +from bot.exts.moderation.silence import Silence +from bot.utils.checks import InWhitelistCheckFailure +from tests.helpers import MockBot, MockContext, MockGuild, MockRole + + +class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): +    """Tests for error handler functionality.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext(bot=self.bot) + +    async def test_error_handler_already_handled(self): +        """Should not do anything when error is already handled by local error handler.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        error = errors.CommandError() +        error.handled = "foo" +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.ctx.send.assert_not_awaited() + +    async def test_error_handler_command_not_found_error_not_invoked_by_handler(self): +        """Should try first (un)silence channel, when fail, try to get tag.""" +        error = errors.CommandNotFound() +        test_cases = ( +            { +                "try_silence_return": True, +                "called_try_get_tag": False +            }, +            { +                "try_silence_return": False, +                "called_try_get_tag": False +            }, +            { +                "try_silence_return": False, +                "called_try_get_tag": True +            } +        ) +        cog = ErrorHandler(self.bot) +        cog.try_silence = AsyncMock() +        cog.try_get_tag = AsyncMock() + +        for case in test_cases: +            with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]): +                self.ctx.reset_mock() +                cog.try_silence.reset_mock(return_value=True) +                cog.try_get_tag.reset_mock() + +                cog.try_silence.return_value = case["try_silence_return"] +                self.ctx.channel.id = 1234 + +                self.assertIsNone(await cog.on_command_error(self.ctx, error)) + +                if case["try_silence_return"]: +                    cog.try_get_tag.assert_not_awaited() +                    cog.try_silence.assert_awaited_once() +                else: +                    cog.try_silence.assert_awaited_once() +                    cog.try_get_tag.assert_awaited_once() + +                self.ctx.send.assert_not_awaited() + +    async def test_error_handler_command_not_found_error_invoked_by_handler(self): +        """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`.""" +        ctx = MockContext(bot=self.bot, invoked_from_error_handler=True) + +        cog = ErrorHandler(self.bot) +        cog.try_silence = AsyncMock() +        cog.try_get_tag = AsyncMock() + +        error = errors.CommandNotFound() + +        self.assertIsNone(await cog.on_command_error(ctx, error)) + +        cog.try_silence.assert_not_awaited() +        cog.try_get_tag.assert_not_awaited() +        self.ctx.send.assert_not_awaited() + +    async def test_error_handler_user_input_error(self): +        """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        cog.handle_user_input_error = AsyncMock() +        error = errors.UserInputError() +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error) + +    async def test_error_handler_check_failure(self): +        """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        cog.handle_check_failure = AsyncMock() +        error = errors.CheckFailure() +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        cog.handle_check_failure.assert_awaited_once_with(self.ctx, error) + +    async def test_error_handler_command_on_cooldown(self): +        """Should send error with `ctx.send` when error is `CommandOnCooldown`.""" +        self.ctx.reset_mock() +        cog = ErrorHandler(self.bot) +        error = errors.CommandOnCooldown(10, 9) +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        self.ctx.send.assert_awaited_once_with(error) + +    async def test_error_handler_command_invoke_error(self): +        """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" +        cog = ErrorHandler(self.bot) +        cog.handle_api_error = AsyncMock() +        cog.handle_unexpected_error = AsyncMock() +        test_cases = ( +            { +                "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))), +                "expect_mock_call": cog.handle_api_error +            }, +            { +                "args": (self.ctx, errors.CommandInvokeError(TypeError)), +                "expect_mock_call": cog.handle_unexpected_error +            }, +            { +                "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))), +                "expect_mock_call": "send" +            }, +            { +                "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))), +                "expect_mock_call": "send" +            } +        ) + +        for case in test_cases: +            with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]): +                self.ctx.send.reset_mock() +                self.assertIsNone(await cog.on_command_error(*case["args"])) +                if case["expect_mock_call"] == "send": +                    self.ctx.send.assert_awaited_once() +                else: +                    case["expect_mock_call"].assert_awaited_once_with( +                        self.ctx, case["args"][1].original +                    ) + +    async def test_error_handler_conversion_error(self): +        """Should call `handle_api_error` or `handle_unexpected_error` depending on original error.""" +        cog = ErrorHandler(self.bot) +        cog.handle_api_error = AsyncMock() +        cog.handle_unexpected_error = AsyncMock() +        cases = ( +            { +                "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())), +                "mock_function_to_call": cog.handle_api_error +            }, +            { +                "error": errors.ConversionError(AsyncMock(), TypeError), +                "mock_function_to_call": cog.handle_unexpected_error +            } +        ) + +        for case in cases: +            with self.subTest(**case): +                self.assertIsNone(await cog.on_command_error(self.ctx, case["error"])) +                case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original) + +    async def test_error_handler_two_other_errors(self): +        """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`.""" +        cog = ErrorHandler(self.bot) +        cog.handle_unexpected_error = AsyncMock() +        errs = ( +            errors.MaxConcurrencyReached(1, MagicMock()), +            errors.ExtensionError(name="foo") +        ) + +        for err in errs: +            with self.subTest(error=err): +                cog.handle_unexpected_error.reset_mock() +                self.assertIsNone(await cog.on_command_error(self.ctx, err)) +                cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err) + +    @patch("bot.exts.backend.error_handler.log") +    async def test_error_handler_other_errors(self, log_mock): +        """Should `log.debug` other errors.""" +        cog = ErrorHandler(self.bot) +        error = errors.DisabledCommand()  # Use this just as a other error +        self.assertIsNone(await cog.on_command_error(self.ctx, error)) +        log_mock.debug.assert_called_once() + + +class TrySilenceTests(unittest.IsolatedAsyncioTestCase): +    """Test for helper functions that handle `CommandNotFound` error.""" + +    def setUp(self): +        self.bot = MockBot() +        self.silence = Silence(self.bot) +        self.bot.get_command.return_value = self.silence.silence +        self.ctx = MockContext(bot=self.bot) +        self.cog = ErrorHandler(self.bot) + +    async def test_try_silence_context_invoked_from_error_handler(self): +        """Should set `Context.invoked_from_error_handler` to `True`.""" +        self.ctx.invoked_with = "foo" +        await self.cog.try_silence(self.ctx) +        self.assertTrue(hasattr(self.ctx, "invoked_from_error_handler")) +        self.assertTrue(self.ctx.invoked_from_error_handler) + +    async def test_try_silence_get_command(self): +        """Should call `get_command` with `silence`.""" +        self.ctx.invoked_with = "foo" +        await self.cog.try_silence(self.ctx) +        self.bot.get_command.assert_called_once_with("silence") + +    async def test_try_silence_no_permissions_to_run(self): +        """Should return `False` because missing permissions.""" +        self.ctx.invoked_with = "foo" +        self.bot.get_command.return_value.can_run = AsyncMock(return_value=False) +        self.assertFalse(await self.cog.try_silence(self.ctx)) + +    async def test_try_silence_no_permissions_to_run_command_error(self): +        """Should return `False` because `CommandError` raised (no permissions).""" +        self.ctx.invoked_with = "foo" +        self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError()) +        self.assertFalse(await self.cog.try_silence(self.ctx)) + +    async def test_try_silence_silencing(self): +        """Should run silence command with correct arguments.""" +        self.bot.get_command.return_value.can_run = AsyncMock(return_value=True) +        test_cases = ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh") + +        for case in test_cases: +            with self.subTest(message=case): +                self.ctx.reset_mock() +                self.ctx.invoked_with = case +                self.assertTrue(await self.cog.try_silence(self.ctx)) +                self.ctx.invoke.assert_awaited_once_with( +                    self.bot.get_command.return_value, +                    duration=min(case.count("h")*2, 15) +                ) + +    async def test_try_silence_unsilence(self): +        """Should call unsilence command.""" +        self.silence.silence.can_run = AsyncMock(return_value=True) +        test_cases = ("unshh", "unshhhhh", "unshhhhhhhhh") + +        for case in test_cases: +            with self.subTest(message=case): +                self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence) +                self.ctx.reset_mock() +                self.ctx.invoked_with = case +                self.assertTrue(await self.cog.try_silence(self.ctx)) +                self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence) + +    async def test_try_silence_no_match(self): +        """Should return `False` when message don't match.""" +        self.ctx.invoked_with = "foo" +        self.assertFalse(await self.cog.try_silence(self.ctx)) + + +class TryGetTagTests(unittest.IsolatedAsyncioTestCase): +    """Tests for `try_get_tag` function.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext() +        self.tag = Tags(self.bot) +        self.cog = ErrorHandler(self.bot) +        self.bot.get_command.return_value = self.tag.get_command + +    async def test_try_get_tag_get_command(self): +        """Should call `Bot.get_command` with `tags get` argument.""" +        self.bot.get_command.reset_mock() +        self.ctx.invoked_with = "foo" +        await self.cog.try_get_tag(self.ctx) +        self.bot.get_command.assert_called_once_with("tags get") + +    async def test_try_get_tag_invoked_from_error_handler(self): +        """`self.ctx` should have `invoked_from_error_handler` `True`.""" +        self.ctx.invoked_from_error_handler = False +        self.ctx.invoked_with = "foo" +        await self.cog.try_get_tag(self.ctx) +        self.assertTrue(self.ctx.invoked_from_error_handler) + +    async def test_try_get_tag_no_permissions(self): +        """Test how to handle checks failing.""" +        self.tag.get_command.can_run = AsyncMock(return_value=False) +        self.ctx.invoked_with = "foo" +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + +    async def test_try_get_tag_command_error(self): +        """Should call `on_command_error` when `CommandError` raised.""" +        err = errors.CommandError() +        self.tag.get_command.can_run = AsyncMock(side_effect=err) +        self.cog.on_command_error = AsyncMock() +        self.ctx.invoked_with = "foo" +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) + +    @patch("bot.exts.backend.error_handler.TagNameConverter") +    async def test_try_get_tag_convert_success(self, tag_converter): +        """Converting tag should successful.""" +        self.ctx.invoked_with = "foo" +        tag_converter.convert = AsyncMock(return_value="foo") +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        tag_converter.convert.assert_awaited_once_with(self.ctx, "foo") +        self.ctx.invoke.assert_awaited_once() + +    @patch("bot.exts.backend.error_handler.TagNameConverter") +    async def test_try_get_tag_convert_fail(self, tag_converter): +        """Converting tag should raise `BadArgument`.""" +        self.ctx.reset_mock() +        self.ctx.invoked_with = "bar" +        tag_converter.convert = AsyncMock(side_effect=errors.BadArgument()) +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        self.ctx.invoke.assert_not_awaited() + +    async def test_try_get_tag_ctx_invoke(self): +        """Should call `ctx.invoke` with proper args/kwargs.""" +        self.ctx.reset_mock() +        self.ctx.invoked_with = "foo" +        self.assertIsNone(await self.cog.try_get_tag(self.ctx)) +        self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo") + +    async def test_dont_call_suggestion_tag_sent(self): +        """Should never call command suggestion if tag is already sent.""" +        self.ctx.invoked_with = "foo" +        self.ctx.invoke = AsyncMock(return_value=True) +        self.cog.send_command_suggestion = AsyncMock() + +        await self.cog.try_get_tag(self.ctx) +        self.cog.send_command_suggestion.assert_not_awaited() + +    @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) +    async def test_dont_call_suggestion_if_user_mod(self): +        """Should not call command suggestion if user is a mod.""" +        self.ctx.invoked_with = "foo" +        self.ctx.invoke = AsyncMock(return_value=False) +        self.ctx.author.roles = [MockRole(id=1234)] +        self.cog.send_command_suggestion = AsyncMock() + +        await self.cog.try_get_tag(self.ctx) +        self.cog.send_command_suggestion.assert_not_awaited() + +    async def test_call_suggestion(self): +        """Should call command suggestion if user is not a mod.""" +        self.ctx.invoked_with = "foo" +        self.ctx.invoke = AsyncMock(return_value=False) +        self.cog.send_command_suggestion = AsyncMock() + +        await self.cog.try_get_tag(self.ctx) +        self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") + + +class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): +    """Individual error categories handler tests.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext(bot=self.bot) +        self.cog = ErrorHandler(self.bot) + +    async def test_handle_input_error_handler_errors(self): +        """Should handle each error probably.""" +        test_cases = ( +            { +                "error": errors.MissingRequiredArgument(MagicMock()), +                "call_prepared": True +            }, +            { +                "error": errors.TooManyArguments(), +                "call_prepared": True +            }, +            { +                "error": errors.BadArgument(), +                "call_prepared": True +            }, +            { +                "error": errors.BadUnionArgument(MagicMock(), MagicMock(), MagicMock()), +                "call_prepared": True +            }, +            { +                "error": errors.ArgumentParsingError(), +                "call_prepared": False +            }, +            { +                "error": errors.UserInputError(), +                "call_prepared": True +            } +        ) + +        for case in test_cases: +            with self.subTest(error=case["error"], call_prepared=case["call_prepared"]): +                self.ctx.reset_mock() +                self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"])) +                self.ctx.send.assert_awaited_once() +                if case["call_prepared"]: +                    self.ctx.send_help.assert_awaited_once() +                else: +                    self.ctx.send_help.assert_not_awaited() + +    async def test_handle_check_failure_errors(self): +        """Should await `ctx.send` when error is check failure.""" +        test_cases = ( +            { +                "error": errors.BotMissingPermissions(MagicMock()), +                "call_ctx_send": True +            }, +            { +                "error": errors.BotMissingRole(MagicMock()), +                "call_ctx_send": True +            }, +            { +                "error": errors.BotMissingAnyRole(MagicMock()), +                "call_ctx_send": True +            }, +            { +                "error": errors.NoPrivateMessage(), +                "call_ctx_send": True +            }, +            { +                "error": InWhitelistCheckFailure(1234), +                "call_ctx_send": True +            }, +            { +                "error": ResponseCodeError(MagicMock()), +                "call_ctx_send": False +            } +        ) + +        for case in test_cases: +            with self.subTest(error=case["error"], call_ctx_send=case["call_ctx_send"]): +                self.ctx.reset_mock() +                await self.cog.handle_check_failure(self.ctx, case["error"]) +                if case["call_ctx_send"]: +                    self.ctx.send.assert_awaited_once() +                else: +                    self.ctx.send.assert_not_awaited() + +    @patch("bot.exts.backend.error_handler.log") +    async def test_handle_api_error(self, log_mock): +        """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code.""" +        test_cases = ( +            { +                "error": ResponseCodeError(AsyncMock(status=400)), +                "log_level": "debug" +            }, +            { +                "error": ResponseCodeError(AsyncMock(status=404)), +                "log_level": "debug" +            }, +            { +                "error": ResponseCodeError(AsyncMock(status=550)), +                "log_level": "warning" +            }, +            { +                "error": ResponseCodeError(AsyncMock(status=1000)), +                "log_level": "warning" +            } +        ) + +        for case in test_cases: +            with self.subTest(error=case["error"], log_level=case["log_level"]): +                self.ctx.reset_mock() +                log_mock.reset_mock() +                await self.cog.handle_api_error(self.ctx, case["error"]) +                self.ctx.send.assert_awaited_once() +                if case["log_level"] == "warning": +                    log_mock.warning.assert_called_once() +                else: +                    log_mock.debug.assert_called_once() + +    @patch("bot.exts.backend.error_handler.push_scope") +    @patch("bot.exts.backend.error_handler.log") +    async def test_handle_unexpected_error(self, log_mock, push_scope_mock): +        """Should `ctx.send` this error, error log this and sent to Sentry.""" +        for case in (None, MockGuild()): +            with self.subTest(guild=case): +                self.ctx.reset_mock() +                log_mock.reset_mock() +                push_scope_mock.reset_mock() + +                self.ctx.guild = case +                await self.cog.handle_unexpected_error(self.ctx, errors.CommandError()) + +                self.ctx.send.assert_awaited_once() +                log_mock.error.assert_called_once() +                push_scope_mock.assert_called_once() + +                set_tag_calls = [ +                    call("command", self.ctx.command.qualified_name), +                    call("message_id", self.ctx.message.id), +                    call("channel_id", self.ctx.channel.id), +                ] +                set_extra_calls = [ +                    call("full_message", self.ctx.message.content) +                ] +                if case: +                    url = ( +                        f"https://discordapp.com/channels/" +                        f"{self.ctx.guild.id}/{self.ctx.channel.id}/{self.ctx.message.id}" +                    ) +                    set_extra_calls.append(call("jump_to", url)) + +                push_scope_mock.set_tag.has_calls(set_tag_calls) +                push_scope_mock.set_extra.has_calls(set_extra_calls) + + +class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase): +    """Other `ErrorHandler` tests.""" + +    def setUp(self): +        self.bot = MockBot() +        self.ctx = MockContext() + +    async def test_get_help_command_command_specified(self): +        """Should return coroutine of help command of specified command.""" +        self.ctx.command = "foo" +        result = ErrorHandler.get_help_command(self.ctx) +        expected = self.ctx.send_help("foo") +        self.assertEqual(result.__qualname__, expected.__qualname__) +        self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals) + +        # Await coroutines to avoid warnings +        await result +        await expected + +    async def test_get_help_command_no_command_specified(self): +        """Should return coroutine of help command.""" +        self.ctx.command = None +        result = ErrorHandler.get_help_command(self.ctx) +        expected = self.ctx.send_help() +        self.assertEqual(result.__qualname__, expected.__qualname__) +        self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals) + +        # Await coroutines to avoid warnings +        await result +        await expected + + +class ErrorHandlerSetupTests(unittest.TestCase): +    """Tests for `ErrorHandler` `setup` function.""" + +    def test_setup(self): +        """Should call `bot.add_cog` with `ErrorHandler`.""" +        bot = MockBot() +        setup(bot) +        bot.add_cog.assert_called_once() diff --git a/tests/bot/exts/info/doc/__init__.py b/tests/bot/exts/info/doc/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/exts/info/doc/__init__.py diff --git a/tests/bot/exts/info/doc/test_parsing.py b/tests/bot/exts/info/doc/test_parsing.py new file mode 100644 index 000000000..1663d8491 --- /dev/null +++ b/tests/bot/exts/info/doc/test_parsing.py @@ -0,0 +1,66 @@ +from unittest import TestCase + +from bot.exts.info.doc import _parsing as parsing + + +class SignatureSplitter(TestCase): + +    def test_basic_split(self): +        test_cases = ( +            ("0,0,0", ["0", "0", "0"]), +            ("0,a=0,a=0", ["0", "a=0", "a=0"]), +        ) +        self._run_tests(test_cases) + +    def test_commas_ignored_in_brackets(self): +        test_cases = ( +            ("0,[0,0],0,[0,0],0", ["0", "[0,0]", "0", "[0,0]", "0"]), +            ("(0,),0,(0,(0,),0),0", ["(0,)", "0", "(0,(0,),0)", "0"]), +        ) +        self._run_tests(test_cases) + +    def test_mixed_brackets(self): +        tests_cases = ( +            ("[0,{0},0],0,{0:0},0", ["[0,{0},0]", "0", "{0:0}", "0"]), +            ("([0],0,0),0,(0,0),0", ["([0],0,0)", "0", "(0,0)", "0"]), +            ("([(0,),(0,)],0),0", ["([(0,),(0,)],0)", "0"]), +        ) +        self._run_tests(tests_cases) + +    def test_string_contents_ignored(self): +        test_cases = ( +            ("'0,0',0,',',0", ["'0,0'", "0", "','", "0"]), +            ("0,[']',0],0", ["0", "[']',0]", "0"]), +            ("{0,0,'}}',0,'{'},0", ["{0,0,'}}',0,'{'}", "0"]), +        ) +        self._run_tests(test_cases) + +    def test_mixed_quotes(self): +        test_cases = ( +            ("\"0',0',\",'0,0',0", ["\"0',0',\"", "'0,0'", "0"]), +            ("\",',\",'\",',0", ["\",',\"", "'\",'", "0"]), +        ) +        self._run_tests(test_cases) + +    def test_quote_escaped(self): +        test_cases = ( +            (r"'\',','\\',0", [r"'\','", r"'\\'", "0"]), +            (r"'0\',0\\\'\\',0", [r"'0\',0\\\'\\'", "0"]), +        ) +        self._run_tests(test_cases) + +    def test_real_signatures(self): +        test_cases = ( +            ("start, stop[, step]", ["start", " stop[, step]"]), +            ("object=b'', encoding='utf-8', errors='strict'", ["object=b''", " encoding='utf-8'", " errors='strict'"]), +            ( +                "typename, field_names, *, rename=False, defaults=None, module=None", +                ["typename", " field_names", " *", " rename=False", " defaults=None", " module=None"] +            ), +        ) +        self._run_tests(test_cases) + +    def _run_tests(self, test_cases): +        for input_string, expected_output in test_cases: +            with self.subTest(input_string=input_string): +                self.assertEqual(list(parsing._split_parameters(input_string)), expected_output) diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py index a996ce477..770660fe3 100644 --- a/tests/bot/exts/info/test_information.py +++ b/tests/bot/exts/info/test_information.py @@ -281,6 +281,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          """The embed should use the string representation of the user if they don't have a nick."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          user = helpers.MockMember() +        user.public_flags = unittest.mock.MagicMock(verified_bot=False)          user.nick = None          user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")          user.colour = 0 @@ -297,6 +298,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):          """The embed should use the nick if it's available."""          ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))          user = helpers.MockMember() +        user.public_flags = unittest.mock.MagicMock(verified_bot=False)          user.nick = "Cat lover"          user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")          user.colour = 0 diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py index c42111f3f..4af84dde5 100644 --- a/tests/bot/test_converters.py +++ b/tests/bot/test_converters.py @@ -10,9 +10,9 @@ from bot.converters import (      Duration,      HushDurationConverter,      ISODateTime, +    PackageName,      TagContentConverter,      TagNameConverter, -    ValidPythonIdentifier,  ) @@ -78,24 +78,23 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):                  with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):                      await TagNameConverter.convert(self.context, invalid_name) -    async def test_valid_python_identifier_for_valid(self): -        """ValidPythonIdentifier returns valid identifiers unchanged.""" -        test_values = ('foo', 'lemon') +    async def test_package_name_for_valid(self): +        """PackageName returns valid package names unchanged.""" +        test_values = ('foo', 'le_mon', 'num83r')          for name in test_values:              with self.subTest(identifier=name): -                conversion = await ValidPythonIdentifier.convert(self.context, name) +                conversion = await PackageName.convert(self.context, name)                  self.assertEqual(name, conversion) -    async def test_valid_python_identifier_for_invalid(self): -        """ValidPythonIdentifier raises the proper exception for invalid identifiers.""" -        test_values = ('nested.stuff', '#####') +    async def test_package_name_for_invalid(self): +        """PackageName raises the proper exception for invalid package names.""" +        test_values = ('text_with_a_dot.', 'UpperCaseName', 'dashed-name')          for name in test_values:              with self.subTest(identifier=name): -                exception_message = f'`{name}` is not a valid Python identifier' -                with self.assertRaisesRegex(BadArgument, re.escape(exception_message)): -                    await ValidPythonIdentifier.convert(self.context, name) +                with self.assertRaises(BadArgument): +                    await PackageName.convert(self.context, name)      async def test_duration_converter_for_valid(self):          """Duration returns the correct `datetime` for valid duration strings.""" diff --git a/tests/helpers.py b/tests/helpers.py index 496363ae3..e3dc5fe5b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -385,6 +385,7 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da  # Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`  context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock()) +context_instance.invoked_from_error_handler = None  class MockContext(CustomMockMixin, unittest.mock.MagicMock): @@ -402,6 +403,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):          self.guild = kwargs.get('guild', MockGuild())          self.author = kwargs.get('author', MockMember())          self.channel = kwargs.get('channel', MockTextChannel()) +        self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)  attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) | 
