aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2021-05-04 11:39:24 +0300
committerGravatar GitHub <[email protected]>2021-05-04 11:39:24 +0300
commiteaba4d3e9df320b3086cdc6f357525c2a53c1482 (patch)
tree12452ea1c9b292bfbd11435ec57c32cdf174c8d5
parentRemoved opinions from text. (diff)
parentMerge pull request #1556 from ToxicKidz/mod-ping-off-embed-timestamp (diff)
Merge branch 'main' into str-join-tag
-rw-r--r--LICENSE-THIRD-PARTY30
-rw-r--r--Pipfile4
-rw-r--r--Pipfile.lock241
-rw-r--r--bot/constants.py11
-rw-r--r--bot/converters.py44
-rw-r--r--bot/decorators.py57
-rw-r--r--bot/exts/backend/branding/_cog.py20
-rw-r--r--bot/exts/backend/error_handler.py26
-rw-r--r--bot/exts/filters/antispam.py19
-rw-r--r--bot/exts/info/code_snippets.py265
-rw-r--r--bot/exts/info/doc.py485
-rw-r--r--bot/exts/info/doc/__init__.py16
-rw-r--r--bot/exts/info/doc/_batch_parser.py186
-rw-r--r--bot/exts/info/doc/_cog.py442
-rw-r--r--bot/exts/info/doc/_html.py136
-rw-r--r--bot/exts/info/doc/_inventory_parser.py126
-rw-r--r--bot/exts/info/doc/_markdown.py58
-rw-r--r--bot/exts/info/doc/_parsing.py256
-rw-r--r--bot/exts/info/doc/_redis_cache.py70
-rw-r--r--bot/exts/info/information.py5
-rw-r--r--bot/exts/info/source.py8
-rw-r--r--bot/exts/moderation/infraction/infractions.py18
-rw-r--r--bot/exts/moderation/infraction/superstarify.py8
-rw-r--r--bot/exts/moderation/modlog.py9
-rw-r--r--bot/exts/moderation/modpings.py138
-rw-r--r--bot/exts/moderation/stream.py80
-rw-r--r--bot/exts/utils/clean.py8
-rw-r--r--bot/exts/utils/reminders.py13
-rw-r--r--bot/exts/utils/snekbox.py10
-rw-r--r--bot/exts/utils/utils.py30
-rw-r--r--bot/log.py37
-rw-r--r--bot/pagination.py36
-rw-r--r--bot/resources/tags/customchecks.md21
-rw-r--r--bot/utils/checks.py8
-rw-r--r--bot/utils/function.py72
-rw-r--r--bot/utils/lock.py37
-rw-r--r--bot/utils/messages.py70
-rw-r--r--bot/utils/scheduling.py10
-rw-r--r--config-default.yml15
-rw-r--r--tests/README.md2
-rw-r--r--tests/bot/exts/backend/test_error_handler.py550
-rw-r--r--tests/bot/exts/info/doc/__init__.py0
-rw-r--r--tests/bot/exts/info/doc/test_parsing.py66
-rw-r--r--tests/bot/exts/info/test_information.py2
-rw-r--r--tests/bot/test_converters.py21
-rw-r--r--tests/helpers.py2
46 files changed, 2887 insertions, 881 deletions
diff --git a/LICENSE-THIRD-PARTY b/LICENSE-THIRD-PARTY
index eacd9b952..ab715630d 100644
--- a/LICENSE-THIRD-PARTY
+++ b/LICENSE-THIRD-PARTY
@@ -35,6 +35,36 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
---------------------------------------------------------------------------------------------------
+ BSD 2-Clause License
+Applies to:
+ - Copyright (c) 2007-2020 by the Sphinx team (see AUTHORS file). All rights reserved.
+ - bot/cogs/doc/inventory_parser.py: _load_v1, _load_v2 and ZlibStreamReader.__aiter__.
+---------------------------------------------------------------------------------------------------
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+* Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+---------------------------------------------------------------------------------------------------
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
Applies to:
- Copyright © 2001-2020 Python Software Foundation. All rights reserved.
diff --git a/Pipfile b/Pipfile
index 2ac5645dd..e924f5ddb 100644
--- a/Pipfile
+++ b/Pipfile
@@ -20,15 +20,13 @@ emoji = "~=0.6"
feedparser = "~=5.2"
fuzzywuzzy = "~=0.17"
lxml = "~=4.4"
-markdownify = "==0.5.3"
+markdownify = "==0.6.1"
more_itertools = "~=8.2"
python-dateutil = "~=2.8"
python-frontmatter = "~=1.0.0"
pyyaml = "~=5.1"
regex = "==2021.4.4"
-requests = "~=2.22"
sentry-sdk = "~=0.19"
-sphinx = "~=2.2"
statsd = "~=3.3"
[dev-packages]
diff --git a/Pipfile.lock b/Pipfile.lock
index d6792ac35..1e1a8167b 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
- "sha256": "fc3421fc4c95d73b620f2b8b0a7dea288d4fc559e0d288ed4ad6cf4eb312f630"
+ "sha256": "e35c9bad81b01152ad3e10b85f1abf5866aa87b9d87e03bc30bdb9d37668ccae"
},
"pipfile-spec": 6,
"requires": {
@@ -99,13 +99,6 @@
"markers": "python_version >= '3.6'",
"version": "==3.3.1"
},
- "alabaster": {
- "hashes": [
- "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359",
- "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"
- ],
- "version": "==0.7.12"
- },
"arrow": {
"hashes": [
"sha256:3515630f11a15c61dcb4cdd245883270dd334c83f3e639824e65a4b79cc48543",
@@ -142,14 +135,6 @@
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==20.3.0"
},
- "babel": {
- "hashes": [
- "sha256:9d35c22fcc79893c3ecc85ac4a56cde1ecf3f19c540bba0922308a6c06ca6fa5",
- "sha256:da031ab54472314f210b0adcff1588ee5d1d1d0ba4dbd07b94dba82bde791e05"
- ],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==2.9.0"
- },
"beautifulsoup4": {
"hashes": [
"sha256:4c98143716ef1cb40bf7f39a8e3eec8f8b009509e74904ba3a7b315431577e35",
@@ -221,7 +206,6 @@
"sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b",
"sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"
],
- "index": "pypi",
"markers": "sys_platform == 'win32'",
"version": "==0.4.4"
},
@@ -249,14 +233,6 @@
"index": "pypi",
"version": "==1.6.0"
},
- "docutils": {
- "hashes": [
- "sha256:a71042bb7207c03d5647f280427f14bfbd1a65c9eb84f4b341d85fafb6bb4bdf",
- "sha256:e2ffeea817964356ba4470efba7c2f42b6b0de0b04e66378507e3e2504bbff4c"
- ],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",
- "version": "==0.17"
- },
"emoji": {
"hashes": [
"sha256:e42da4f8d648f8ef10691bc246f682a1ec6b18373abfd9be10ec0b398823bd11"
@@ -345,27 +321,11 @@
},
"idna": {
"hashes": [
- "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6",
- "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"
+ "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16",
+ "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"
],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==2.10"
- },
- "imagesize": {
- "hashes": [
- "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1",
- "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"
- ],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==1.2.0"
- },
- "jinja2": {
- "hashes": [
- "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419",
- "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"
- ],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",
- "version": "==2.11.3"
+ "markers": "python_version >= '3.4'",
+ "version": "==3.1"
},
"lxml": {
"hashes": [
@@ -411,69 +371,11 @@
},
"markdownify": {
"hashes": [
- "sha256:30be8340724e706c9e811c27fe8c1542cf74a15b46827924fff5c54b40dd9b0d",
- "sha256:a69588194fd76634f0139d6801b820fd652dc5eeba9530e90d323dfdc0155252"
+ "sha256:31d7c13ac2ada8bfc7535a25fee6622ca720e1b5f2d4a9cbc429d167c21f886d",
+ "sha256:7489fd5c601536996a376c4afbcd1dd034db7690af807120681461e82fbc0acc"
],
"index": "pypi",
- "version": "==0.5.3"
- },
- "markupsafe": {
- "hashes": [
- "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473",
- "sha256:09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161",
- "sha256:09c4b7f37d6c648cb13f9230d847adf22f8171b1ccc4d5682398e77f40309235",
- "sha256:1027c282dad077d0bae18be6794e6b6b8c91d58ed8a8d89a89d59693b9131db5",
- "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42",
- "sha256:195d7d2c4fbb0ee8139a6cf67194f3973a6b3042d742ebe0a9ed36d8b6f0c07f",
- "sha256:22c178a091fc6630d0d045bdb5992d2dfe14e3259760e713c490da5323866c39",
- "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff",
- "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b",
- "sha256:2beec1e0de6924ea551859edb9e7679da6e4870d32cb766240ce17e0a0ba2014",
- "sha256:3b8a6499709d29c2e2399569d96719a1b21dcd94410a586a18526b143ec8470f",
- "sha256:43a55c2930bbc139570ac2452adf3d70cdbb3cfe5912c71cdce1c2c6bbd9c5d1",
- "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e",
- "sha256:500d4957e52ddc3351cabf489e79c91c17f6e0899158447047588650b5e69183",
- "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66",
- "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b",
- "sha256:62fe6c95e3ec8a7fad637b7f3d372c15ec1caa01ab47926cfdf7a75b40e0eac1",
- "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15",
- "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1",
- "sha256:6f1e273a344928347c1290119b493a1f0303c52f5a5eae5f16d74f48c15d4a85",
- "sha256:6fffc775d90dcc9aed1b89219549b329a9250d918fd0b8fa8d93d154918422e1",
- "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e",
- "sha256:79855e1c5b8da654cf486b830bd42c06e8780cea587384cf6545b7d9ac013a0b",
- "sha256:7c1699dfe0cf8ff607dbdcc1e9b9af1755371f92a68f706051cc8c37d447c905",
- "sha256:7fed13866cf14bba33e7176717346713881f56d9d2bcebab207f7a036f41b850",
- "sha256:84dee80c15f1b560d55bcfe6d47b27d070b4681c699c572af2e3c7cc90a3b8e0",
- "sha256:88e5fcfb52ee7b911e8bb6d6aa2fd21fbecc674eadd44118a9cc3863f938e735",
- "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d",
- "sha256:98bae9582248d6cf62321dcb52aaf5d9adf0bad3b40582925ef7c7f0ed85fceb",
- "sha256:98c7086708b163d425c67c7a91bad6e466bb99d797aa64f965e9d25c12111a5e",
- "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d",
- "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c",
- "sha256:a6a744282b7718a2a62d2ed9d993cad6f5f585605ad352c11de459f4108df0a1",
- "sha256:acf08ac40292838b3cbbb06cfe9b2cb9ec78fce8baca31ddb87aaac2e2dc3bc2",
- "sha256:ade5e387d2ad0d7ebf59146cc00c8044acbd863725f887353a10df825fc8ae21",
- "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2",
- "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5",
- "sha256:b1dba4527182c95a0db8b6060cc98ac49b9e2f5e64320e2b56e47cb2831978c7",
- "sha256:b2051432115498d3562c084a49bba65d97cf251f5a331c64a12ee7e04dacc51b",
- "sha256:b7d644ddb4dbd407d31ffb699f1d140bc35478da613b441c582aeb7c43838dd8",
- "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6",
- "sha256:bf5aa3cbcfdf57fa2ee9cd1822c862ef23037f5c832ad09cfea57fa846dec193",
- "sha256:c8716a48d94b06bb3b2524c2b77e055fb313aeb4ea620c8dd03a105574ba704f",
- "sha256:caabedc8323f1e93231b52fc32bdcde6db817623d33e100708d9a68e1f53b26b",
- "sha256:cd5df75523866410809ca100dc9681e301e3c27567cf498077e8551b6d20e42f",
- "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2",
- "sha256:d53bc011414228441014aa71dbec320c66468c1030aae3a6e29778a3382d96e5",
- "sha256:d73a845f227b0bfe8a7455ee623525ee656a9e2e749e4742706d80a6065d5e2c",
- "sha256:d9be0ba6c527163cbed5e0857c451fcd092ce83947944d6c14bc95441203f032",
- "sha256:e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7",
- "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be",
- "sha256:feb7b34d6325451ef96bc0e36e1a6c0c1c64bc1fbec4b854f4529e51887b1621"
- ],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==1.1.1"
+ "version": "==0.6.1"
},
"more-itertools": {
"hashes": [
@@ -533,14 +435,6 @@
"markers": "python_version >= '3.5'",
"version": "==4.0.2"
},
- "packaging": {
- "hashes": [
- "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5",
- "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"
- ],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==20.9"
- },
"pamqp": {
"hashes": [
"sha256:2f81b5c186f668a67f165193925b6bfd83db4363a6222f599517f29ecee60b02",
@@ -590,31 +484,6 @@
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==2.20"
},
- "pygments": {
- "hashes": [
- "sha256:2656e1a6edcdabf4275f9a3640db59fd5de107d88e8663c5d4e9a0fa62f77f94",
- "sha256:534ef71d539ae97d4c3a4cf7d6f110f214b0e687e92f9cb9d2a3b0d3101289c8"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==2.8.1"
- },
- "pyparsing": {
- "hashes": [
- "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1",
- "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"
- ],
- "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==2.4.7"
- },
- "pyreadline": {
- "hashes": [
- "sha256:4530592fc2e85b25b1a9f79664433da09237c1a270e4d78ea5aa3a2c7229e2d1",
- "sha256:65540c21bfe14405a3a77e4c085ecfce88724743a4ead47c66b84defcf82c32e",
- "sha256:9ce5fa65b8992dfa373bddc5b6e0864ead8f291c94fbfec05fbd5c836162e67b"
- ],
- "markers": "sys_platform == 'win32'",
- "version": "==2.1"
- },
"python-dateutil": {
"hashes": [
"sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c",
@@ -631,13 +500,6 @@
"index": "pypi",
"version": "==1.0.0"
},
- "pytz": {
- "hashes": [
- "sha256:83a4a90894bf38e243cf052c8b58f381bfe9a7a483f6a9cab140bc7f702ac4da",
- "sha256:eb10ce3e7736052ed3623d49975ce333bcd712c7bb19a58b9e2089d4057d0798"
- ],
- "version": "==2021.1"
- },
"pyyaml": {
"hashes": [
"sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf",
@@ -728,14 +590,6 @@
"index": "pypi",
"version": "==2021.4.4"
},
- "requests": {
- "hashes": [
- "sha256:27973dd4a904a4f13b263a19c866c13b92a39ed1c964655f025f3f8d3d75b804",
- "sha256:c210084e36a42ae6b9219e00e48287def368a26d03a048ddad7bfee44f75871e"
- ],
- "index": "pypi",
- "version": "==2.25.1"
- },
"sentry-sdk": {
"hashes": [
"sha256:4ae8d1ced6c67f1c8ea51d82a16721c166c489b76876c9f2c202b8a50334b237",
@@ -749,16 +603,9 @@
"sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",
"sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"
],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'",
"version": "==1.15.0"
},
- "snowballstemmer": {
- "hashes": [
- "sha256:b51b447bea85f9968c13b650126a888aabd4cb4463fca868ec596826325dedc2",
- "sha256:e997baa4f2e9139951b6f4c631bad912dfd3c792467e2f03d7239464af90e914"
- ],
- "version": "==2.1.0"
- },
"sortedcontainers": {
"hashes": [
"sha256:37257a32add0a3ee490bb170b599e93095eed89a55da91fa9f48753ea12fd73f",
@@ -774,62 +621,6 @@
"markers": "python_version >= '3.0'",
"version": "==2.2.1"
},
- "sphinx": {
- "hashes": [
- "sha256:b4c750d546ab6d7e05bdff6ac24db8ae3e8b8253a3569b754e445110a0a12b66",
- "sha256:fc312670b56cb54920d6cc2ced455a22a547910de10b3142276495ced49231cb"
- ],
- "index": "pypi",
- "version": "==2.4.4"
- },
- "sphinxcontrib-applehelp": {
- "hashes": [
- "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a",
- "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==1.0.2"
- },
- "sphinxcontrib-devhelp": {
- "hashes": [
- "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e",
- "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==1.0.2"
- },
- "sphinxcontrib-htmlhelp": {
- "hashes": [
- "sha256:3c0bc24a2c41e340ac37c85ced6dafc879ab485c095b1d65d2461ac2f7cca86f",
- "sha256:e8f5bb7e31b2dbb25b9cc435c8ab7a79787ebf7f906155729338f3156d93659b"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==1.0.3"
- },
- "sphinxcontrib-jsmath": {
- "hashes": [
- "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178",
- "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==1.0.1"
- },
- "sphinxcontrib-qthelp": {
- "hashes": [
- "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72",
- "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==1.0.3"
- },
- "sphinxcontrib-serializinghtml": {
- "hashes": [
- "sha256:eaa0eccc86e982a9b939b2b82d12cc5d013385ba5eadcc7e4fed23f4405f77bc",
- "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a"
- ],
- "markers": "python_version >= '3.5'",
- "version": "==1.1.4"
- },
"statsd": {
"hashes": [
"sha256:c610fb80347fca0ef62666d241bce64184bd7cc1efe582f9690e045c25535eaa",
@@ -1103,11 +894,11 @@
},
"idna": {
"hashes": [
- "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6",
- "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"
+ "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16",
+ "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"
],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
- "version": "==2.10"
+ "markers": "python_version >= '3.4'",
+ "version": "==3.1"
},
"mccabe": {
"hashes": [
@@ -1203,7 +994,7 @@
"sha256:27973dd4a904a4f13b263a19c866c13b92a39ed1c964655f025f3f8d3d75b804",
"sha256:c210084e36a42ae6b9219e00e48287def368a26d03a048ddad7bfee44f75871e"
],
- "index": "pypi",
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",
"version": "==2.25.1"
},
"six": {
@@ -1211,7 +1002,7 @@
"sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",
"sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"
],
- "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'",
"version": "==1.15.0"
},
"snowballstemmer": {
@@ -1226,7 +1017,7 @@
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
"sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"
],
- "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
+ "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'",
"version": "==0.10.2"
},
"urllib3": {
diff --git a/bot/constants.py b/bot/constants.py
index 6d14bbb3a..7b2a38079 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -175,13 +175,14 @@ class YAMLGetter(type):
if cls.subsection is not None:
return _CONFIG_YAML[cls.section][cls.subsection][name]
return _CONFIG_YAML[cls.section][name]
- except KeyError:
+ except KeyError as e:
dotted_path = '.'.join(
(cls.section, cls.subsection, name)
if cls.subsection is not None else (cls.section, name)
)
- log.critical(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.")
- raise
+ # Only an INFO log since this can be caught through `hasattr` or `getattr`.
+ log.info(f"Tried accessing configuration variable at `{dotted_path}`, but it could not be found.")
+ raise AttributeError(repr(name)) from e
def __getitem__(cls, name):
return cls.__getattr__(name)
@@ -199,6 +200,7 @@ class Bot(metaclass=YAMLGetter):
prefix: str
sentry_dsn: Optional[str]
token: str
+ trace_loggers: Optional[str]
class Redis(metaclass=YAMLGetter):
@@ -279,6 +281,8 @@ class Emojis(metaclass=YAMLGetter):
badge_partner: str
badge_staff: str
badge_verified_bot_developer: str
+ verified_bot: str
+ bot: str
defcon_shutdown: str # noqa: E704
defcon_unshutdown: str # noqa: E704
@@ -491,6 +495,7 @@ class Roles(metaclass=YAMLGetter):
domain_leads: int
helpers: int
moderators: int
+ mod_team: int
owners: int
project_leads: int
diff --git a/bot/converters.py b/bot/converters.py
index 67525cd4d..3bf05cfb3 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -15,6 +15,7 @@ from discord.utils import DISCORD_EPOCH, snowflake_time
from bot.api import ResponseCodeError
from bot.constants import URLs
+from bot.exts.info.doc import _inventory_parser
from bot.utils.regex import INVITE_RE
from bot.utils.time import parse_duration_string
@@ -127,22 +128,20 @@ class ValidFilterListType(Converter):
return list_type
-class ValidPythonIdentifier(Converter):
+class PackageName(Converter):
"""
- A converter that checks whether the given string is a valid Python identifier.
+ A converter that checks whether the given string is a valid package name.
- This is used to have package names that correspond to how you would use the package in your
- code, e.g. `import package`.
-
- Raises `BadArgument` if the argument is not a valid Python identifier, and simply passes through
- the given argument otherwise.
+ Package names are used for stats and are restricted to the a-z and _ characters.
"""
- @staticmethod
- async def convert(ctx: Context, argument: str) -> str:
- """Checks whether the given string is a valid Python identifier."""
- if not argument.isidentifier():
- raise BadArgument(f"`{argument}` is not a valid Python identifier")
+ PACKAGE_NAME_RE = re.compile(r"[^a-z0-9_]")
+
+ @classmethod
+ async def convert(cls, ctx: Context, argument: str) -> str:
+ """Checks whether the given string is a valid package name."""
+ if cls.PACKAGE_NAME_RE.search(argument):
+ raise BadArgument("The provided package name is not valid; please only use the _, 0-9, and a-z characters.")
return argument
@@ -178,6 +177,27 @@ class ValidURL(Converter):
return url
+class Inventory(Converter):
+ """
+ Represents an Intersphinx inventory URL.
+
+ This converter checks whether intersphinx accepts the given inventory URL, and raises
+ `BadArgument` if that is not the case or if the url is unreachable.
+
+ Otherwise, it returns the url and the fetched inventory dict in a tuple.
+ """
+
+ @staticmethod
+ async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]:
+ """Convert url to Intersphinx inventory URL."""
+ await ctx.trigger_typing()
+ if (inventory := await _inventory_parser.fetch_inventory(url)) is None:
+ raise BadArgument(
+ f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts."
+ )
+ return url, inventory
+
+
class Snowflake(IDConverter):
"""
Converts to an int if the argument is a valid Discord snowflake.
diff --git a/bot/decorators.py b/bot/decorators.py
index 0b50cc365..e971a5bd3 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -1,9 +1,9 @@
import asyncio
import functools
import logging
+import types
import typing as t
from contextlib import suppress
-from functools import wraps
from discord import Member, NotFound
from discord.ext import commands
@@ -11,7 +11,8 @@ from discord.ext.commands import Cog, Context
from bot.constants import Channels, DEBUG_MODE, RedirectOutput
from bot.utils import function
-from bot.utils.checks import in_whitelist_check
+from bot.utils.checks import ContextCheckFailure, in_whitelist_check
+from bot.utils.function import command_wraps
log = logging.getLogger(__name__)
@@ -44,6 +45,49 @@ def in_whitelist(
return commands.check(predicate)
+class NotInBlacklistCheckFailure(ContextCheckFailure):
+ """Raised when the 'not_in_blacklist' check fails."""
+
+
+def not_in_blacklist(
+ *,
+ channels: t.Container[int] = (),
+ categories: t.Container[int] = (),
+ roles: t.Container[int] = (),
+ override_roles: t.Container[int] = (),
+ redirect: t.Optional[int] = Channels.bot_commands,
+ fail_silently: bool = False,
+) -> t.Callable:
+ """
+ Check if a command was not issued in a blacklisted context.
+
+ The blacklists that can be provided are:
+
+ - `channels`: a container with channel ids for blacklisted channels
+ - `categories`: a container with category ids for blacklisted categories
+ - `roles`: a container with role ids for blacklisted roles
+
+ If the command was invoked in a context that was blacklisted, the member is either
+ redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
+ told that they're not allowed to use this particular command (if `None` was passed).
+
+ The blacklist can be overridden through the roles specified in `override_roles`.
+ """
+ def predicate(ctx: Context) -> bool:
+ """Check if command was issued in a blacklisted context."""
+ not_blacklisted = not in_whitelist_check(ctx, channels, categories, roles, fail_silently=True)
+ overridden = in_whitelist_check(ctx, roles=override_roles, fail_silently=True)
+
+ success = not_blacklisted or overridden
+
+ if not success and not fail_silently:
+ raise NotInBlacklistCheckFailure(redirect)
+
+ return success
+
+ return commands.check(predicate)
+
+
def has_no_roles(*roles: t.Union[str, int]) -> t.Callable:
"""
Returns True if the user does not have any of the roles specified.
@@ -71,8 +115,8 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N
This decorator must go before (below) the `command` decorator.
"""
- def wrap(func: t.Callable) -> t.Callable:
- @wraps(func)
+ def wrap(func: types.FunctionType) -> types.FunctionType:
+ @command_wraps(func)
async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
if ctx.channel.id == destination_channel:
log.trace(f"Command {ctx.command.name} was invoked in destination_channel, not redirecting")
@@ -106,7 +150,6 @@ def redirect_output(destination_channel: int, bypass_roles: t.Container[int] = N
with suppress(NotFound):
await ctx.message.delete()
log.trace("Redirect output: Deleted invocation message")
-
return inner
return wrap
@@ -123,8 +166,8 @@ def respect_role_hierarchy(member_arg: function.Argument) -> t.Callable:
This decorator must go before (below) the `command` decorator.
"""
- def decorator(func: t.Callable) -> t.Callable:
- @wraps(func)
+ def decorator(func: types.FunctionType) -> types.FunctionType:
+ @command_wraps(func)
async def wrapper(*args, **kwargs) -> None:
log.trace(f"{func.__name__}: respect role hierarchy decorator called")
diff --git a/bot/exts/backend/branding/_cog.py b/bot/exts/backend/branding/_cog.py
index 0a4ddcc88..47c379a34 100644
--- a/bot/exts/backend/branding/_cog.py
+++ b/bot/exts/backend/branding/_cog.py
@@ -3,12 +3,13 @@ import contextlib
import logging
import random
import typing as t
-from datetime import datetime, time, timedelta
+from datetime import timedelta
from enum import Enum
from operator import attrgetter
import async_timeout
import discord
+from arrow import Arrow
from async_rediscache import RedisCache
from discord.ext import commands, tasks
@@ -57,6 +58,8 @@ def extract_event_duration(event: Event) -> str:
Extract a human-readable, year-agnostic duration string from `event`.
In the case that `event` is a fallback event, resolves to 'Fallback'.
+
+ For 1-day events, only the single date is shown, instead of a period.
"""
if event.meta.is_fallback:
return "Fallback"
@@ -65,6 +68,9 @@ def extract_event_duration(event: Event) -> str:
start_date = event.meta.start_date.strftime(fmt)
end_date = event.meta.end_date.strftime(fmt)
+ if start_date == end_date:
+ return start_date
+
return f"{start_date} - {end_date}"
@@ -208,7 +214,7 @@ class Branding(commands.Cog):
if success:
await self.cache_icons.increment(next_icon) # Push the icon into the next iteration.
- timestamp = datetime.utcnow().timestamp()
+ timestamp = Arrow.utcnow().timestamp()
await self.cache_information.set("last_rotation_timestamp", timestamp)
return success
@@ -229,8 +235,8 @@ class Branding(commands.Cog):
await self.rotate_icons()
return
- last_rotation = datetime.fromtimestamp(last_rotation_timestamp)
- difference = (datetime.utcnow() - last_rotation) + timedelta(minutes=5)
+ last_rotation = Arrow.utcfromtimestamp(last_rotation_timestamp)
+ difference = (Arrow.utcnow() - last_rotation) + timedelta(minutes=5)
log.trace(f"Icons last rotated at {last_rotation} (difference: {difference}).")
@@ -485,11 +491,11 @@ class Branding(commands.Cog):
await self.daemon_loop()
log.trace("Daemon before: calculating time to sleep before loop begins.")
- now = datetime.utcnow()
+ now = Arrow.utcnow()
# The actual midnight moment is offset into the future to prevent issues with imprecise sleep.
- tomorrow = now + timedelta(days=1)
- midnight = datetime.combine(tomorrow, time(minute=1))
+ tomorrow = now.shift(days=1)
+ midnight = tomorrow.replace(hour=0, minute=1, second=0, microsecond=0)
sleep_secs = (midnight - now).total_seconds()
log.trace(f"Daemon before: sleeping {sleep_secs} seconds before next-up midnight: {midnight}.")
diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py
index 76ab7dfc2..d8de177f5 100644
--- a/bot/exts/backend/error_handler.py
+++ b/bot/exts/backend/error_handler.py
@@ -1,4 +1,3 @@
-import contextlib
import difflib
import logging
import typing as t
@@ -12,7 +11,7 @@ from bot.bot import Bot
from bot.constants import Colours, Icons, MODERATION_ROLES
from bot.converters import TagNameConverter
from bot.errors import InvalidInfractedUser, LockedResourceError
-from bot.utils.checks import InWhitelistCheckFailure
+from bot.utils.checks import ContextCheckFailure
log = logging.getLogger(__name__)
@@ -60,7 +59,7 @@ class ErrorHandler(Cog):
log.trace(f"Command {command} had its error already handled locally; ignoring.")
return
- if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"):
+ if isinstance(e, errors.CommandNotFound) and not getattr(ctx, "invoked_from_error_handler", False):
if await self.try_silence(ctx):
return
# Try to look for a tag with the command's name
@@ -162,9 +161,8 @@ class ErrorHandler(Cog):
f"and the fallback tag failed validation in TagNameConverter."
)
else:
- with contextlib.suppress(ResponseCodeError):
- if await ctx.invoke(tags_get_command, tag_name=tag_name):
- return
+ if await ctx.invoke(tags_get_command, tag_name=tag_name):
+ return
if not any(role.id in MODERATION_ROLES for role in ctx.author.roles):
await self.send_command_suggestion(ctx, ctx.invoked_with)
@@ -214,32 +212,30 @@ class ErrorHandler(Cog):
* ArgumentParsingError: send an error message
* Other: send an error message and the help command
"""
- prepared_help_command = self.get_help_command(ctx)
-
if isinstance(e, errors.MissingRequiredArgument):
embed = self._get_error_embed("Missing required argument", e.param.name)
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.missing_required_argument")
elif isinstance(e, errors.TooManyArguments):
embed = self._get_error_embed("Too many arguments", str(e))
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.too_many_arguments")
elif isinstance(e, errors.BadArgument):
embed = self._get_error_embed("Bad argument", str(e))
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.bad_argument")
elif isinstance(e, errors.BadUnionArgument):
embed = self._get_error_embed("Bad argument", f"{e}\n{e.errors[-1]}")
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.bad_union_argument")
elif isinstance(e, errors.ArgumentParsingError):
embed = self._get_error_embed("Argument parsing error", str(e))
await ctx.send(embed=embed)
- prepared_help_command.close()
+ self.get_help_command(ctx).close()
self.bot.stats.incr("errors.argument_parsing_error")
else:
embed = self._get_error_embed(
@@ -247,7 +243,7 @@ class ErrorHandler(Cog):
"Something about your input seems off. Check the arguments and try again."
)
await ctx.send(embed=embed)
- await prepared_help_command
+ await self.get_help_command(ctx)
self.bot.stats.incr("errors.other_user_input_error")
@staticmethod
@@ -274,7 +270,7 @@ class ErrorHandler(Cog):
await ctx.send(
"Sorry, it looks like I don't have the permissions or roles I need to do that."
)
- elif isinstance(e, (InWhitelistCheckFailure, errors.NoPrivateMessage)):
+ elif isinstance(e, (ContextCheckFailure, errors.NoPrivateMessage)):
ctx.bot.stats.incr("errors.wrong_channel_or_dm_error")
await ctx.send(e)
diff --git a/bot/exts/filters/antispam.py b/bot/exts/filters/antispam.py
index af8528a68..7555e25a2 100644
--- a/bot/exts/filters/antispam.py
+++ b/bot/exts/filters/antispam.py
@@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
-from operator import itemgetter
+from operator import attrgetter, itemgetter
from typing import Dict, Iterable, List, Set
from discord import Colour, Member, Message, NotFound, Object, TextChannel
@@ -18,6 +18,7 @@ from bot.constants import (
)
from bot.converters import Duration
from bot.exts.moderation.modlog import ModLog
+from bot.utils import lock, scheduling
from bot.utils.messages import format_user, send_attachments
@@ -114,7 +115,7 @@ class AntiSpam(Cog):
self.message_deletion_queue = dict()
- self.bot.loop.create_task(self.alert_on_validation_error())
+ self.bot.loop.create_task(self.alert_on_validation_error(), name="AntiSpam.alert_on_validation_error")
@property
def mod_log(self) -> ModLog:
@@ -191,7 +192,10 @@ class AntiSpam(Cog):
if channel.id not in self.message_deletion_queue:
log.trace(f"Creating queue for channel `{channel.id}`")
self.message_deletion_queue[message.channel.id] = DeletionContext(channel)
- self.bot.loop.create_task(self._process_deletion_context(message.channel.id))
+ scheduling.create_task(
+ self._process_deletion_context(message.channel.id),
+ name=f"AntiSpam._process_deletion_context({message.channel.id})"
+ )
# Add the relevant of this trigger to the Deletion Context
await self.message_deletion_queue[message.channel.id].add(
@@ -201,16 +205,15 @@ class AntiSpam(Cog):
)
for member in members:
-
- # Fire it off as a background task to ensure
- # that the sleep doesn't block further tasks
- self.bot.loop.create_task(
- self.punish(message, member, full_reason)
+ scheduling.create_task(
+ self.punish(message, member, full_reason),
+ name=f"AntiSpam.punish(message={message.id}, member={member.id}, rule={rule_name})"
)
await self.maybe_delete_messages(channel, relevant_messages)
break
+ @lock.lock_arg("antispam.punish", "member", attrgetter("id"))
async def punish(self, msg: Message, member: Member, reason: str) -> None:
"""Punishes the given member for triggering an antispam rule."""
if not any(role.id == self.muted_role.id for role in member.roles):
diff --git a/bot/exts/info/code_snippets.py b/bot/exts/info/code_snippets.py
new file mode 100644
index 000000000..06885410b
--- /dev/null
+++ b/bot/exts/info/code_snippets.py
@@ -0,0 +1,265 @@
+import logging
+import re
+import textwrap
+from typing import Any
+from urllib.parse import quote_plus
+
+from aiohttp import ClientResponseError
+from discord import Message
+from discord.ext.commands import Cog
+
+from bot.bot import Bot
+from bot.constants import Channels
+from bot.utils.messages import wait_for_deletion
+
+log = logging.getLogger(__name__)
+
+GITHUB_RE = re.compile(
+ r'https://github\.com/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/blob/'
+ r'(?P<path>[^#>]+)(\?[^#>]+)?(#L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)'
+)
+
+GITHUB_GIST_RE = re.compile(
+ r'https://gist\.github\.com/([a-zA-Z0-9-]+)/(?P<gist_id>[a-zA-Z0-9]+)/*'
+ r'(?P<revision>[a-zA-Z0-9]*)/*#file-(?P<file_path>[^#>]+?)(\?[^#>]+)?'
+ r'(-L(?P<start_line>\d+)([-~:]L(?P<end_line>\d+))?)'
+)
+
+GITHUB_HEADERS = {'Accept': 'application/vnd.github.v3.raw'}
+
+GITLAB_RE = re.compile(
+ r'https://gitlab\.com/(?P<repo>[\w.-]+/[\w.-]+)/\-/blob/(?P<path>[^#>]+)'
+ r'(\?[^#>]+)?(#L(?P<start_line>\d+)(-(?P<end_line>\d+))?)'
+)
+
+BITBUCKET_RE = re.compile(
+ r'https://bitbucket\.org/(?P<repo>[a-zA-Z0-9-]+/[\w.-]+)/src/(?P<ref>[0-9a-zA-Z]+)'
+ r'/(?P<file_path>[^#>]+)(\?[^#>]+)?(#lines-(?P<start_line>\d+)(:(?P<end_line>\d+))?)'
+)
+
+
+class CodeSnippets(Cog):
+ """
+ Cog that parses and sends code snippets to Discord.
+
+ Matches each message against a regex and prints the contents of all matched snippets.
+ """
+
+ async def _fetch_response(self, url: str, response_format: str, **kwargs) -> Any:
+ """Makes http requests using aiohttp."""
+ async with self.bot.http_session.get(url, raise_for_status=True, **kwargs) as response:
+ if response_format == 'text':
+ return await response.text()
+ elif response_format == 'json':
+ return await response.json()
+
+ def _find_ref(self, path: str, refs: tuple) -> tuple:
+ """Loops through all branches and tags to find the required ref."""
+ # Base case: there is no slash in the branch name
+ ref, file_path = path.split('/', 1)
+ # In case there are slashes in the branch name, we loop through all branches and tags
+ for possible_ref in refs:
+ if path.startswith(possible_ref['name'] + '/'):
+ ref = possible_ref['name']
+ file_path = path[len(ref) + 1:]
+ break
+ return ref, file_path
+
+ async def _fetch_github_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub repo."""
+ # Search the GitHub API for the specified branch
+ branches = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/branches',
+ 'json',
+ headers=GITHUB_HEADERS
+ )
+ tags = await self._fetch_response(f'https://api.github.com/repos/{repo}/tags', 'json', headers=GITHUB_HEADERS)
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+
+ file_contents = await self._fetch_response(
+ f'https://api.github.com/repos/{repo}/contents/{file_path}?ref={ref}',
+ 'text',
+ headers=GITHUB_HEADERS,
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_github_gist_snippet(
+ self,
+ gist_id: str,
+ revision: str,
+ file_path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitHub gist."""
+ gist_json = await self._fetch_response(
+ f'https://api.github.com/gists/{gist_id}{f"/{revision}" if len(revision) > 0 else ""}',
+ 'json',
+ headers=GITHUB_HEADERS,
+ )
+
+ # Check each file in the gist for the specified file
+ for gist_file in gist_json['files']:
+ if file_path == gist_file.lower().replace('.', '-'):
+ file_contents = await self._fetch_response(
+ gist_json['files'][gist_file]['raw_url'],
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, gist_file, start_line, end_line)
+ return ''
+
+ async def _fetch_gitlab_snippet(
+ self,
+ repo: str,
+ path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a GitLab repo."""
+ enc_repo = quote_plus(repo)
+
+ # Searches the GitLab API for the specified branch
+ branches = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/branches',
+ 'json'
+ )
+ tags = await self._fetch_response(f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags', 'json')
+ refs = branches + tags
+ ref, file_path = self._find_ref(path, refs)
+ enc_ref = quote_plus(ref)
+ enc_file_path = quote_plus(file_path)
+
+ file_contents = await self._fetch_response(
+ f'https://gitlab.com/api/v4/projects/{enc_repo}/repository/files/{enc_file_path}/raw?ref={enc_ref}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ async def _fetch_bitbucket_snippet(
+ self,
+ repo: str,
+ ref: str,
+ file_path: str,
+ start_line: str,
+ end_line: str
+ ) -> str:
+ """Fetches a snippet from a BitBucket repo."""
+ file_contents = await self._fetch_response(
+ f'https://bitbucket.org/{quote_plus(repo)}/raw/{quote_plus(ref)}/{quote_plus(file_path)}',
+ 'text',
+ )
+ return self._snippet_to_codeblock(file_contents, file_path, start_line, end_line)
+
+ def _snippet_to_codeblock(self, file_contents: str, file_path: str, start_line: str, end_line: str) -> str:
+ """
+ Given the entire file contents and target lines, creates a code block.
+
+ First, we split the file contents into a list of lines and then keep and join only the required
+ ones together.
+
+ We then dedent the lines to look nice, and replace all ` characters with `\u200b to prevent
+ markdown injection.
+
+ Finally, we surround the code with ``` characters.
+ """
+ # Parse start_line and end_line into integers
+ if end_line is None:
+ start_line = end_line = int(start_line)
+ else:
+ start_line = int(start_line)
+ end_line = int(end_line)
+
+ split_file_contents = file_contents.splitlines()
+
+ # Make sure that the specified lines are in range
+ if start_line > end_line:
+ start_line, end_line = end_line, start_line
+ if start_line > len(split_file_contents) or end_line < 1:
+ return ''
+ start_line = max(1, start_line)
+ end_line = min(len(split_file_contents), end_line)
+
+ # Gets the code lines, dedents them, and inserts zero-width spaces to prevent Markdown injection
+ required = '\n'.join(split_file_contents[start_line - 1:end_line])
+ required = textwrap.dedent(required).rstrip().replace('`', '`\u200b')
+
+ # Extracts the code language and checks whether it's a "valid" language
+ language = file_path.split('/')[-1].split('.')[-1]
+ trimmed_language = language.replace('-', '').replace('+', '').replace('_', '')
+ is_valid_language = trimmed_language.isalnum()
+ if not is_valid_language:
+ language = ''
+
+ # Adds a label showing the file path to the snippet
+ if start_line == end_line:
+ ret = f'`{file_path}` line {start_line}\n'
+ else:
+ ret = f'`{file_path}` lines {start_line} to {end_line}\n'
+
+ if len(required) != 0:
+ return f'{ret}```{language}\n{required}```'
+ # Returns an empty codeblock if the snippet is empty
+ return f'{ret}``` ```'
+
+ def __init__(self, bot: Bot):
+ """Initializes the cog's bot."""
+ self.bot = bot
+
+ self.pattern_handlers = [
+ (GITHUB_RE, self._fetch_github_snippet),
+ (GITHUB_GIST_RE, self._fetch_github_gist_snippet),
+ (GITLAB_RE, self._fetch_gitlab_snippet),
+ (BITBUCKET_RE, self._fetch_bitbucket_snippet)
+ ]
+
+ @Cog.listener()
+ async def on_message(self, message: Message) -> None:
+ """Checks if the message has a snippet link, removes the embed, then sends the snippet contents."""
+ if not message.author.bot:
+ all_snippets = []
+
+ for pattern, handler in self.pattern_handlers:
+ for match in pattern.finditer(message.content):
+ try:
+ snippet = await handler(**match.groupdict())
+ all_snippets.append((match.start(), snippet))
+ except ClientResponseError as error:
+ error_message = error.message # noqa: B306
+ log.log(
+ logging.DEBUG if error.status == 404 else logging.ERROR,
+ f'Failed to fetch code snippet from {match[0]!r}: {error.status} '
+ f'{error_message} for GET {error.request_info.real_url.human_repr()}'
+ )
+
+ # Sorts the list of snippets by their match index and joins them into a single message
+ message_to_send = '\n'.join(map(lambda x: x[1], sorted(all_snippets)))
+
+ if 0 < len(message_to_send) <= 2000 and message_to_send.count('\n') <= 15:
+ await message.edit(suppress=True)
+ if len(message_to_send) > 1000 and message.channel.id != Channels.bot_commands:
+ # Redirects to #bot-commands if the snippet contents are too long
+ await self.bot.wait_until_guild_available()
+ await message.channel.send(('The snippet you tried to send was too long. Please '
+ f'see <#{Channels.bot_commands}> for the full snippet.'))
+ bot_commands_channel = self.bot.get_channel(Channels.bot_commands)
+ await wait_for_deletion(
+ await bot_commands_channel.send(message_to_send),
+ (message.author.id,)
+ )
+ else:
+ await wait_for_deletion(
+ await message.channel.send(message_to_send),
+ (message.author.id,)
+ )
+
+
+def setup(bot: Bot) -> None:
+ """Load the CodeSnippets cog."""
+ bot.add_cog(CodeSnippets(bot))
diff --git a/bot/exts/info/doc.py b/bot/exts/info/doc.py
deleted file mode 100644
index 9b5bd6504..000000000
--- a/bot/exts/info/doc.py
+++ /dev/null
@@ -1,485 +0,0 @@
-import asyncio
-import functools
-import logging
-import re
-import textwrap
-from contextlib import suppress
-from types import SimpleNamespace
-from typing import Optional, Tuple
-
-import discord
-from bs4 import BeautifulSoup
-from bs4.element import PageElement, Tag
-from discord.errors import NotFound
-from discord.ext import commands
-from markdownify import MarkdownConverter
-from requests import ConnectTimeout, ConnectionError, HTTPError
-from sphinx.ext import intersphinx
-from urllib3.exceptions import ProtocolError
-
-from bot.bot import Bot
-from bot.constants import MODERATION_ROLES, RedirectOutput
-from bot.converters import ValidPythonIdentifier, ValidURL
-from bot.pagination import LinePaginator
-from bot.utils.cache import AsyncCache
-from bot.utils.messages import wait_for_deletion
-
-
-log = logging.getLogger(__name__)
-logging.getLogger('urllib3').setLevel(logging.WARNING)
-
-# Since Intersphinx is intended to be used with Sphinx,
-# we need to mock its configuration.
-SPHINX_MOCK_APP = SimpleNamespace(
- config=SimpleNamespace(
- intersphinx_timeout=3,
- tls_verify=True,
- user_agent="python3:python-discord/bot:1.0.0"
- )
-)
-
-NO_OVERRIDE_GROUPS = (
- "2to3fixer",
- "token",
- "label",
- "pdbcommand",
- "term",
-)
-NO_OVERRIDE_PACKAGES = (
- "python",
-)
-
-SEARCH_END_TAG_ATTRS = (
- "data",
- "function",
- "class",
- "exception",
- "seealso",
- "section",
- "rubric",
- "sphinxsidebar",
-)
-UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶")
-WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)")
-
-FAILED_REQUEST_RETRY_AMOUNT = 3
-NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay
-
-symbol_cache = AsyncCache()
-
-
-class DocMarkdownConverter(MarkdownConverter):
- """Subclass markdownify's MarkdownCoverter to provide custom conversion methods."""
-
- def convert_code(self, el: PageElement, text: str) -> str:
- """Undo `markdownify`s underscore escaping."""
- return f"`{text}`".replace('\\', '')
-
- def convert_pre(self, el: PageElement, text: str) -> str:
- """Wrap any codeblocks in `py` for syntax highlighting."""
- code = ''.join(el.strings)
- return f"```py\n{code}```"
-
-
-def markdownify(html: str) -> DocMarkdownConverter:
- """Create a DocMarkdownConverter object from the input html."""
- return DocMarkdownConverter(bullets='•').convert(html)
-
-
-class InventoryURL(commands.Converter):
- """
- Represents an Intersphinx inventory URL.
-
- This converter checks whether intersphinx accepts the given inventory URL, and raises
- `BadArgument` if that is not the case.
-
- Otherwise, it simply passes through the given URL.
- """
-
- @staticmethod
- async def convert(ctx: commands.Context, url: str) -> str:
- """Convert url to Intersphinx inventory URL."""
- try:
- intersphinx.fetch_inventory(SPHINX_MOCK_APP, '', url)
- except AttributeError:
- raise commands.BadArgument(f"Failed to fetch Intersphinx inventory from URL `{url}`.")
- except ConnectionError:
- if url.startswith('https'):
- raise commands.BadArgument(
- f"Cannot establish a connection to `{url}`. Does it support HTTPS?"
- )
- raise commands.BadArgument(f"Cannot connect to host with URL `{url}`.")
- except ValueError:
- raise commands.BadArgument(
- f"Failed to read Intersphinx inventory from URL `{url}`. "
- "Are you sure that it's a valid inventory file?"
- )
- return url
-
-
-class Doc(commands.Cog):
- """A set of commands for querying & displaying documentation."""
-
- def __init__(self, bot: Bot):
- self.base_urls = {}
- self.bot = bot
- self.inventories = {}
- self.renamed_symbols = set()
-
- self.bot.loop.create_task(self.init_refresh_inventory())
-
- async def init_refresh_inventory(self) -> None:
- """Refresh documentation inventory on cog initialization."""
- await self.bot.wait_until_guild_available()
- await self.refresh_inventory()
-
- async def update_single(
- self, package_name: str, base_url: str, inventory_url: str
- ) -> None:
- """
- Rebuild the inventory for a single package.
-
- Where:
- * `package_name` is the package name to use, appears in the log
- * `base_url` is the root documentation URL for the specified package, used to build
- absolute paths that link to specific symbols
- * `inventory_url` is the absolute URL to the intersphinx inventory, fetched by running
- `intersphinx.fetch_inventory` in an executor on the bot's event loop
- """
- self.base_urls[package_name] = base_url
-
- package = await self._fetch_inventory(inventory_url)
- if not package:
- return None
-
- for group, value in package.items():
- for symbol, (package_name, _version, relative_doc_url, _) in value.items():
- absolute_doc_url = base_url + relative_doc_url
-
- if symbol in self.inventories:
- group_name = group.split(":")[1]
- symbol_base_url = self.inventories[symbol].split("/", 3)[2]
- if (
- group_name in NO_OVERRIDE_GROUPS
- or any(package in symbol_base_url for package in NO_OVERRIDE_PACKAGES)
- ):
-
- symbol = f"{group_name}.{symbol}"
- # If renamed `symbol` already exists, add library name in front to differentiate between them.
- if symbol in self.renamed_symbols:
- # Split `package_name` because of packages like Pillow that have spaces in them.
- symbol = f"{package_name.split()[0]}.{symbol}"
-
- self.inventories[symbol] = absolute_doc_url
- self.renamed_symbols.add(symbol)
- continue
-
- self.inventories[symbol] = absolute_doc_url
-
- log.trace(f"Fetched inventory for {package_name}.")
-
- async def refresh_inventory(self) -> None:
- """Refresh internal documentation inventory."""
- log.debug("Refreshing documentation inventory...")
-
- # Clear the old base URLS and inventories to ensure
- # that we start from a fresh local dataset.
- # Also, reset the cache used for fetching documentation.
- self.base_urls.clear()
- self.inventories.clear()
- self.renamed_symbols.clear()
- symbol_cache.clear()
-
- # Run all coroutines concurrently - since each of them performs a HTTP
- # request, this speeds up fetching the inventory data heavily.
- coros = [
- self.update_single(
- package["package"], package["base_url"], package["inventory_url"]
- ) for package in await self.bot.api_client.get('bot/documentation-links')
- ]
- await asyncio.gather(*coros)
-
- async def get_symbol_html(self, symbol: str) -> Optional[Tuple[list, str]]:
- """
- Given a Python symbol, return its signature and description.
-
- The first tuple element is the signature of the given symbol as a markup-free string, and
- the second tuple element is the description of the given symbol with HTML markup included.
-
- If the given symbol is a module, returns a tuple `(None, str)`
- else if the symbol could not be found, returns `None`.
- """
- url = self.inventories.get(symbol)
- if url is None:
- return None
-
- async with self.bot.http_session.get(url) as response:
- html = await response.text(encoding='utf-8')
-
- # Find the signature header and parse the relevant parts.
- symbol_id = url.split('#')[-1]
- soup = BeautifulSoup(html, 'lxml')
- symbol_heading = soup.find(id=symbol_id)
- search_html = str(soup)
-
- if symbol_heading is None:
- return None
-
- if symbol_id == f"module-{symbol}":
- # Get page content from the module headerlink to the
- # first tag that has its class in `SEARCH_END_TAG_ATTRS`
- start_tag = symbol_heading.find("a", attrs={"class": "headerlink"})
- if start_tag is None:
- return [], ""
-
- end_tag = start_tag.find_next(self._match_end_tag)
- if end_tag is None:
- return [], ""
-
- description_start_index = search_html.find(str(start_tag.parent)) + len(str(start_tag.parent))
- description_end_index = search_html.find(str(end_tag))
- description = search_html[description_start_index:description_end_index]
- signatures = None
-
- else:
- signatures = []
- description = str(symbol_heading.find_next_sibling("dd"))
- description_pos = search_html.find(description)
- # Get text of up to 3 signatures, remove unwanted symbols
- for element in [symbol_heading] + symbol_heading.find_next_siblings("dt", limit=2):
- signature = UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text)
- if signature and search_html.find(str(element)) < description_pos:
- signatures.append(signature)
-
- return signatures, description.replace('¶', '')
-
- @symbol_cache(arg_offset=1)
- async def get_symbol_embed(self, symbol: str) -> Optional[discord.Embed]:
- """
- Attempt to scrape and fetch the data for the given `symbol`, and build an embed from its contents.
-
- If the symbol is known, an Embed with documentation about it is returned.
- """
- scraped_html = await self.get_symbol_html(symbol)
- if scraped_html is None:
- return None
-
- signatures = scraped_html[0]
- permalink = self.inventories[symbol]
- description = markdownify(scraped_html[1])
-
- # Truncate the description of the embed to the last occurrence
- # of a double newline (interpreted as a paragraph) before index 1000.
- if len(description) > 1000:
- shortened = description[:1000]
- description_cutoff = shortened.rfind('\n\n', 100)
- if description_cutoff == -1:
- # Search the shortened version for cutoff points in decreasing desirability,
- # cutoff at 1000 if none are found.
- for string in (". ", ", ", ",", " "):
- description_cutoff = shortened.rfind(string)
- if description_cutoff != -1:
- break
- else:
- description_cutoff = 1000
- description = description[:description_cutoff]
-
- # If there is an incomplete code block, cut it out
- if description.count("```") % 2:
- codeblock_start = description.rfind('```py')
- description = description[:codeblock_start].rstrip()
- description += f"... [read more]({permalink})"
-
- description = WHITESPACE_AFTER_NEWLINES_RE.sub('', description)
- if signatures is None:
- # If symbol is a module, don't show signature.
- embed_description = description
-
- elif not signatures:
- # It's some "meta-page", for example:
- # https://docs.djangoproject.com/en/dev/ref/views/#module-django.views
- embed_description = "This appears to be a generic page not tied to a specific symbol."
-
- else:
- embed_description = "".join(f"```py\n{textwrap.shorten(signature, 500)}```" for signature in signatures)
- embed_description += f"\n{description}"
-
- embed = discord.Embed(
- title=f'`{symbol}`',
- url=permalink,
- description=embed_description
- )
- # Show all symbols with the same name that were renamed in the footer.
- embed.set_footer(
- text=", ".join(renamed for renamed in self.renamed_symbols - {symbol} if renamed.endswith(f".{symbol}"))
- )
- return embed
-
- @commands.group(name='docs', aliases=('doc', 'd'), invoke_without_command=True)
- async def docs_group(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None:
- """Lookup documentation for Python symbols."""
- await self.get_command(ctx, symbol)
-
- @docs_group.command(name='get', aliases=('g',))
- async def get_command(self, ctx: commands.Context, symbol: commands.clean_content = None) -> None:
- """
- Return a documentation embed for a given symbol.
-
- If no symbol is given, return a list of all available inventories.
-
- Examples:
- !docs
- !docs aiohttp
- !docs aiohttp.ClientSession
- !docs get aiohttp.ClientSession
- """
- if symbol is None:
- inventory_embed = discord.Embed(
- title=f"All inventories (`{len(self.base_urls)}` total)",
- colour=discord.Colour.blue()
- )
-
- lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items())
- if self.base_urls:
- await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False)
-
- else:
- inventory_embed.description = "Hmmm, seems like there's nothing here yet."
- await ctx.send(embed=inventory_embed)
-
- else:
- # Fetching documentation for a symbol (at least for the first time, since
- # caching is used) takes quite some time, so let's send typing to indicate
- # that we got the command, but are still working on it.
- async with ctx.typing():
- doc_embed = await self.get_symbol_embed(symbol)
-
- if doc_embed is None:
- error_embed = discord.Embed(
- description=f"Sorry, I could not find any documentation for `{symbol}`.",
- colour=discord.Colour.red()
- )
- error_message = await ctx.send(embed=error_embed)
- with suppress(NotFound):
- await error_message.delete(delay=NOT_FOUND_DELETE_DELAY)
- await ctx.message.delete(delay=NOT_FOUND_DELETE_DELAY)
- else:
- msg = await ctx.send(embed=doc_embed)
- await wait_for_deletion(msg, (ctx.author.id,))
-
- @docs_group.command(name='set', aliases=('s',))
- @commands.has_any_role(*MODERATION_ROLES)
- async def set_command(
- self, ctx: commands.Context, package_name: ValidPythonIdentifier,
- base_url: ValidURL, inventory_url: InventoryURL
- ) -> None:
- """
- Adds a new documentation metadata object to the site's database.
-
- The database will update the object, should an existing item with the specified `package_name` already exist.
-
- Example:
- !docs set \
- python \
- https://docs.python.org/3/ \
- https://docs.python.org/3/objects.inv
- """
- body = {
- 'package': package_name,
- 'base_url': base_url,
- 'inventory_url': inventory_url
- }
- await self.bot.api_client.post('bot/documentation-links', json=body)
-
- log.info(
- f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n"
- f"Package name: {package_name}\n"
- f"Base url: {base_url}\n"
- f"Inventory URL: {inventory_url}"
- )
-
- # Rebuilding the inventory can take some time, so lets send out a
- # typing event to show that the Bot is still working.
- async with ctx.typing():
- await self.refresh_inventory()
- await ctx.send(f"Added package `{package_name}` to database and refreshed inventory.")
-
- @docs_group.command(name='delete', aliases=('remove', 'rm', 'd'))
- @commands.has_any_role(*MODERATION_ROLES)
- async def delete_command(self, ctx: commands.Context, package_name: ValidPythonIdentifier) -> None:
- """
- Removes the specified package from the database.
-
- Examples:
- !docs delete aiohttp
- """
- await self.bot.api_client.delete(f'bot/documentation-links/{package_name}')
-
- async with ctx.typing():
- # Rebuild the inventory to ensure that everything
- # that was from this package is properly deleted.
- await self.refresh_inventory()
- await ctx.send(f"Successfully deleted `{package_name}` and refreshed inventory.")
-
- @docs_group.command(name="refresh", aliases=("rfsh", "r"))
- @commands.has_any_role(*MODERATION_ROLES)
- async def refresh_command(self, ctx: commands.Context) -> None:
- """Refresh inventories and send differences to channel."""
- old_inventories = set(self.base_urls)
- with ctx.typing():
- await self.refresh_inventory()
- # Get differences of added and removed inventories
- added = ', '.join(inv for inv in self.base_urls if inv not in old_inventories)
- if added:
- added = f"+ {added}"
-
- removed = ', '.join(inv for inv in old_inventories if inv not in self.base_urls)
- if removed:
- removed = f"- {removed}"
-
- embed = discord.Embed(
- title="Inventories refreshed",
- description=f"```diff\n{added}\n{removed}```" if added or removed else ""
- )
- await ctx.send(embed=embed)
-
- async def _fetch_inventory(self, inventory_url: str) -> Optional[dict]:
- """Get and return inventory from `inventory_url`. If fetching fails, return None."""
- fetch_func = functools.partial(intersphinx.fetch_inventory, SPHINX_MOCK_APP, '', inventory_url)
- for retry in range(1, FAILED_REQUEST_RETRY_AMOUNT+1):
- try:
- package = await self.bot.loop.run_in_executor(None, fetch_func)
- except ConnectTimeout:
- log.error(
- f"Fetching of inventory {inventory_url} timed out,"
- f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})"
- )
- except ProtocolError:
- log.error(
- f"Connection lost while fetching inventory {inventory_url},"
- f" trying again. ({retry}/{FAILED_REQUEST_RETRY_AMOUNT})"
- )
- except HTTPError as e:
- log.error(f"Fetching of inventory {inventory_url} failed with status code {e.response.status_code}.")
- return None
- except ConnectionError:
- log.error(f"Couldn't establish connection to inventory {inventory_url}.")
- return None
- else:
- return package
- log.error(f"Fetching of inventory {inventory_url} failed.")
- return None
-
- @staticmethod
- def _match_end_tag(tag: Tag) -> bool:
- """Matches `tag` if its class value is in `SEARCH_END_TAG_ATTRS` or the tag is table."""
- for attr in SEARCH_END_TAG_ATTRS:
- if attr in tag.get("class", ()):
- return True
-
- return tag.name == "table"
-
-
-def setup(bot: Bot) -> None:
- """Load the Doc cog."""
- bot.add_cog(Doc(bot))
diff --git a/bot/exts/info/doc/__init__.py b/bot/exts/info/doc/__init__.py
new file mode 100644
index 000000000..38a8975c0
--- /dev/null
+++ b/bot/exts/info/doc/__init__.py
@@ -0,0 +1,16 @@
+from bot.bot import Bot
+from ._redis_cache import DocRedisCache
+
+MAX_SIGNATURE_AMOUNT = 3
+PRIORITY_PACKAGES = (
+ "python",
+)
+NAMESPACE = "doc"
+
+doc_cache = DocRedisCache(namespace=NAMESPACE)
+
+
+def setup(bot: Bot) -> None:
+ """Load the Doc cog."""
+ from ._cog import DocCog
+ bot.add_cog(DocCog(bot))
diff --git a/bot/exts/info/doc/_batch_parser.py b/bot/exts/info/doc/_batch_parser.py
new file mode 100644
index 000000000..369bb462c
--- /dev/null
+++ b/bot/exts/info/doc/_batch_parser.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+import asyncio
+import collections
+import logging
+from collections import defaultdict
+from contextlib import suppress
+from operator import attrgetter
+from typing import Deque, Dict, List, NamedTuple, Optional, Union
+
+import discord
+from bs4 import BeautifulSoup
+
+import bot
+from bot.constants import Channels
+from bot.utils import scheduling
+from . import _cog, doc_cache
+from ._parsing import get_symbol_markdown
+
+log = logging.getLogger(__name__)
+
+
+class StaleInventoryNotifier:
+ """Handle sending notifications about stale inventories through `DocItem`s to dev log."""
+
+ def __init__(self):
+ self._init_task = bot.instance.loop.create_task(
+ self._init_channel(),
+ name="StaleInventoryNotifier channel init"
+ )
+ self._warned_urls = set()
+
+ async def _init_channel(self) -> None:
+ """Wait for guild and get channel."""
+ await bot.instance.wait_until_guild_available()
+ self._dev_log = bot.instance.get_channel(Channels.dev_log)
+
+ async def send_warning(self, doc_item: _cog.DocItem) -> None:
+ """Send a warning to dev log if one wasn't already sent for `item`'s url."""
+ if doc_item.url not in self._warned_urls:
+ self._warned_urls.add(doc_item.url)
+ await self._init_task
+ embed = discord.Embed(
+ description=f"Doc item `{doc_item.symbol_id=}` present in loaded documentation inventories "
+ f"not found on [site]({doc_item.url}), inventories may need to be refreshed."
+ )
+ await self._dev_log.send(embed=embed)
+
+
+class QueueItem(NamedTuple):
+ """Contains a `DocItem` and the `BeautifulSoup` object needed to parse it."""
+
+ doc_item: _cog.DocItem
+ soup: BeautifulSoup
+
+ def __eq__(self, other: Union[QueueItem, _cog.DocItem]):
+ if isinstance(other, _cog.DocItem):
+ return self.doc_item == other
+ return NamedTuple.__eq__(self, other)
+
+
+class ParseResultFuture(asyncio.Future):
+ """
+ Future with metadata for the parser class.
+
+ `user_requested` is set by the parser when a Future is requested by an user and moved to the front,
+ allowing the futures to only be waited for when clearing if they were user requested.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.user_requested = False
+
+
+class BatchParser:
+ """
+ Get the Markdown of all symbols on a page and send them to redis when a symbol is requested.
+
+ DocItems are added through the `add_item` method which adds them to the `_page_doc_items` dict.
+ `get_markdown` is used to fetch the Markdown; when this is used for the first time on a page,
+ all of the symbols are queued to be parsed to avoid multiple web requests to the same page.
+ """
+
+ def __init__(self):
+ self._queue: Deque[QueueItem] = collections.deque()
+ self._page_doc_items: Dict[str, List[_cog.DocItem]] = defaultdict(list)
+ self._item_futures: Dict[_cog.DocItem, ParseResultFuture] = defaultdict(ParseResultFuture)
+ self._parse_task = None
+
+ self.stale_inventory_notifier = StaleInventoryNotifier()
+
+ async def get_markdown(self, doc_item: _cog.DocItem) -> Optional[str]:
+ """
+ Get the result Markdown of `doc_item`.
+
+ If no symbols were fetched from `doc_item`s page before,
+ the HTML has to be fetched and then all items from the page are put into the parse queue.
+
+ Not safe to run while `self.clear` is running.
+ """
+ if doc_item not in self._item_futures and doc_item not in self._queue:
+ self._item_futures[doc_item].user_requested = True
+
+ async with bot.instance.http_session.get(doc_item.url) as response:
+ soup = await bot.instance.loop.run_in_executor(
+ None,
+ BeautifulSoup,
+ await response.text(encoding="utf8"),
+ "lxml",
+ )
+
+ self._queue.extendleft(QueueItem(item, soup) for item in self._page_doc_items[doc_item.url])
+ log.debug(f"Added items from {doc_item.url} to the parse queue.")
+
+ if self._parse_task is None:
+ self._parse_task = scheduling.create_task(self._parse_queue(), name="Queue parse")
+ else:
+ self._item_futures[doc_item].user_requested = True
+ with suppress(ValueError):
+ # If the item is not in the queue then the item is already parsed or is being parsed
+ self._move_to_front(doc_item)
+ return await self._item_futures[doc_item]
+
+ async def _parse_queue(self) -> None:
+ """
+ Parse all items from the queue, setting their result Markdown on the futures and sending them to redis.
+
+ The coroutine will run as long as the queue is not empty, resetting `self._parse_task` to None when finished.
+ """
+ log.trace("Starting queue parsing.")
+ try:
+ while self._queue:
+ item, soup = self._queue.pop()
+ markdown = None
+
+ if (future := self._item_futures[item]).done():
+ # Some items are present in the inventories multiple times under different symbol names,
+ # if we already parsed an equal item, we can just skip it.
+ continue
+
+ try:
+ markdown = await bot.instance.loop.run_in_executor(None, get_symbol_markdown, soup, item)
+ if markdown is not None:
+ await doc_cache.set(item, markdown)
+ else:
+ # Don't wait for this coro as the parsing doesn't depend on anything it does.
+ scheduling.create_task(
+ self.stale_inventory_notifier.send_warning(item), name="Stale inventory warning"
+ )
+ except Exception:
+ log.exception(f"Unexpected error when handling {item}")
+ future.set_result(markdown)
+ del self._item_futures[item]
+ await asyncio.sleep(0.1)
+ finally:
+ self._parse_task = None
+ log.trace("Finished parsing queue.")
+
+ def _move_to_front(self, item: Union[QueueItem, _cog.DocItem]) -> None:
+ """Move `item` to the front of the parse queue."""
+ # The parse queue stores soups along with the doc symbols in QueueItem objects,
+ # in case we're moving a DocItem we have to get the associated QueueItem first and then move it.
+ item_index = self._queue.index(item)
+ queue_item = self._queue[item_index]
+ del self._queue[item_index]
+
+ self._queue.append(queue_item)
+ log.trace(f"Moved {item} to the front of the queue.")
+
+ def add_item(self, doc_item: _cog.DocItem) -> None:
+ """Map a DocItem to its page so that the symbol will be parsed once the page is requested."""
+ self._page_doc_items[doc_item.url].append(doc_item)
+
+ async def clear(self) -> None:
+ """
+ Clear all internal symbol data.
+
+ Wait for all user-requested symbols to be parsed before clearing the parser.
+ """
+ for future in filter(attrgetter("user_requested"), self._item_futures.values()):
+ await future
+ if self._parse_task is not None:
+ self._parse_task.cancel()
+ self._queue.clear()
+ self._page_doc_items.clear()
+ self._item_futures.clear()
diff --git a/bot/exts/info/doc/_cog.py b/bot/exts/info/doc/_cog.py
new file mode 100644
index 000000000..2a8016fb8
--- /dev/null
+++ b/bot/exts/info/doc/_cog.py
@@ -0,0 +1,442 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import sys
+import textwrap
+from collections import defaultdict
+from contextlib import suppress
+from types import SimpleNamespace
+from typing import Dict, NamedTuple, Optional, Tuple, Union
+
+import aiohttp
+import discord
+from discord.ext import commands
+
+from bot.bot import Bot
+from bot.constants import MODERATION_ROLES, RedirectOutput
+from bot.converters import Inventory, PackageName, ValidURL, allowed_strings
+from bot.pagination import LinePaginator
+from bot.utils.lock import SharedEvent, lock
+from bot.utils.messages import send_denial, wait_for_deletion
+from bot.utils.scheduling import Scheduler
+from . import NAMESPACE, PRIORITY_PACKAGES, _batch_parser, doc_cache
+from ._inventory_parser import InventoryDict, fetch_inventory
+
+log = logging.getLogger(__name__)
+
+# symbols with a group contained here will get the group prefixed on duplicates
+FORCE_PREFIX_GROUPS = (
+ "2to3fixer",
+ "token",
+ "label",
+ "pdbcommand",
+ "term",
+)
+NOT_FOUND_DELETE_DELAY = RedirectOutput.delete_delay
+# Delay to wait before trying to reach a rescheduled inventory again, in minutes
+FETCH_RESCHEDULE_DELAY = SimpleNamespace(first=2, repeated=5)
+
+COMMAND_LOCK_SINGLETON = "inventory refresh"
+
+
+class DocItem(NamedTuple):
+ """Holds inventory symbol information."""
+
+ package: str # Name of the package name the symbol is from
+ group: str # Interpshinx "role" of the symbol, for example `label` or `method`
+ base_url: str # Absolute path to to which the relative path resolves, same for all items with the same package
+ relative_url_path: str # Relative path to the page where the symbol is located
+ symbol_id: str # Fragment id used to locate the symbol on the page
+
+ @property
+ def url(self) -> str:
+ """Return the absolute url to the symbol."""
+ return self.base_url + self.relative_url_path
+
+
+class DocCog(commands.Cog):
+ """A set of commands for querying & displaying documentation."""
+
+ def __init__(self, bot: Bot):
+ # Contains URLs to documentation home pages.
+ # Used to calculate inventory diffs on refreshes and to display all currently stored inventories.
+ self.base_urls = {}
+ self.bot = bot
+ self.doc_symbols: Dict[str, DocItem] = {} # Maps symbol names to objects containing their metadata.
+ self.item_fetcher = _batch_parser.BatchParser()
+ # Maps a conflicting symbol name to a list of the new, disambiguated names created from conflicts with the name.
+ self.renamed_symbols = defaultdict(list)
+
+ self.inventory_scheduler = Scheduler(self.__class__.__name__)
+
+ self.refresh_event = asyncio.Event()
+ self.refresh_event.set()
+ self.symbol_get_event = SharedEvent()
+
+ self.init_refresh_task = self.bot.loop.create_task(
+ self.init_refresh_inventory(),
+ name="Doc inventory init"
+ )
+
+ @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True)
+ async def init_refresh_inventory(self) -> None:
+ """Refresh documentation inventory on cog initialization."""
+ await self.bot.wait_until_guild_available()
+ await self.refresh_inventories()
+
+ def update_single(self, package_name: str, base_url: str, inventory: InventoryDict) -> None:
+ """
+ Build the inventory for a single package.
+
+ Where:
+ * `package_name` is the package name to use in logs and when qualifying symbols
+ * `base_url` is the root documentation URL for the specified package, used to build
+ absolute paths that link to specific symbols
+ * `package` is the content of a intersphinx inventory.
+ """
+ self.base_urls[package_name] = base_url
+
+ for group, items in inventory.items():
+ for symbol_name, relative_doc_url in items:
+
+ # e.g. get 'class' from 'py:class'
+ group_name = group.split(":")[1]
+ symbol_name = self.ensure_unique_symbol_name(
+ package_name,
+ group_name,
+ symbol_name,
+ )
+
+ relative_url_path, _, symbol_id = relative_doc_url.partition("#")
+ # Intern fields that have shared content so we're not storing unique strings for every object
+ doc_item = DocItem(
+ package_name,
+ sys.intern(group_name),
+ base_url,
+ sys.intern(relative_url_path),
+ symbol_id,
+ )
+ self.doc_symbols[symbol_name] = doc_item
+ self.item_fetcher.add_item(doc_item)
+
+ log.trace(f"Fetched inventory for {package_name}.")
+
+ async def update_or_reschedule_inventory(
+ self,
+ api_package_name: str,
+ base_url: str,
+ inventory_url: str,
+ ) -> None:
+ """
+ Update the cog's inventories, or reschedule this method to execute again if the remote inventory is unreachable.
+
+ The first attempt is rescheduled to execute in `FETCH_RESCHEDULE_DELAY.first` minutes, the subsequent attempts
+ in `FETCH_RESCHEDULE_DELAY.repeated` minutes.
+ """
+ package = await fetch_inventory(inventory_url)
+
+ if not package:
+ if api_package_name in self.inventory_scheduler:
+ self.inventory_scheduler.cancel(api_package_name)
+ delay = FETCH_RESCHEDULE_DELAY.repeated
+ else:
+ delay = FETCH_RESCHEDULE_DELAY.first
+ log.info(f"Failed to fetch inventory; attempting again in {delay} minutes.")
+ self.inventory_scheduler.schedule_later(
+ delay*60,
+ api_package_name,
+ self.update_or_reschedule_inventory(api_package_name, base_url, inventory_url),
+ )
+ else:
+ self.update_single(api_package_name, base_url, package)
+
+ def ensure_unique_symbol_name(self, package_name: str, group_name: str, symbol_name: str) -> str:
+ """
+ Ensure `symbol_name` doesn't overwrite an another symbol in `doc_symbols`.
+
+ For conflicts, rename either the current symbol or the existing symbol with which it conflicts.
+ Store the new name in `renamed_symbols` and return the name to use for the symbol.
+
+ If the existing symbol was renamed or there was no conflict, the returned name is equivalent to `symbol_name`.
+ """
+ if (item := self.doc_symbols.get(symbol_name)) is None:
+ return symbol_name # There's no conflict so it's fine to simply use the given symbol name.
+
+ def rename(prefix: str, *, rename_extant: bool = False) -> str:
+ new_name = f"{prefix}.{symbol_name}"
+ if new_name in self.doc_symbols:
+ # If there's still a conflict, qualify the name further.
+ if rename_extant:
+ new_name = f"{item.package}.{item.group}.{symbol_name}"
+ else:
+ new_name = f"{package_name}.{group_name}.{symbol_name}"
+
+ self.renamed_symbols[symbol_name].append(new_name)
+
+ if rename_extant:
+ # Instead of renaming the current symbol, rename the symbol with which it conflicts.
+ self.doc_symbols[new_name] = self.doc_symbols[symbol_name]
+ return symbol_name
+ else:
+ return new_name
+
+ # Certain groups are added as prefixes to disambiguate the symbols.
+ if group_name in FORCE_PREFIX_GROUPS:
+ return rename(group_name)
+
+ # The existing symbol with which the current symbol conflicts should have a group prefix.
+ # It currently doesn't have the group prefix because it's only added once there's a conflict.
+ elif item.group in FORCE_PREFIX_GROUPS:
+ return rename(item.group, rename_extant=True)
+
+ elif package_name in PRIORITY_PACKAGES:
+ return rename(item.package, rename_extant=True)
+
+ # If we can't specially handle the symbol through its group or package,
+ # fall back to prepending its package name to the front.
+ else:
+ return rename(package_name)
+
+ async def refresh_inventories(self) -> None:
+ """Refresh internal documentation inventories."""
+ self.refresh_event.clear()
+ await self.symbol_get_event.wait()
+ log.debug("Refreshing documentation inventory...")
+ self.inventory_scheduler.cancel_all()
+
+ self.base_urls.clear()
+ self.doc_symbols.clear()
+ self.renamed_symbols.clear()
+ await self.item_fetcher.clear()
+
+ coros = [
+ self.update_or_reschedule_inventory(
+ package["package"], package["base_url"], package["inventory_url"]
+ ) for package in await self.bot.api_client.get("bot/documentation-links")
+ ]
+ await asyncio.gather(*coros)
+ log.debug("Finished inventory refresh.")
+ self.refresh_event.set()
+
+ def get_symbol_item(self, symbol_name: str) -> Tuple[str, Optional[DocItem]]:
+ """
+ Get the `DocItem` and the symbol name used to fetch it from the `doc_symbols` dict.
+
+ If the doc item is not found directly from the passed in name and the name contains a space,
+ the first word of the name will be attempted to be used to get the item.
+ """
+ doc_item = self.doc_symbols.get(symbol_name)
+ if doc_item is None and " " in symbol_name:
+ symbol_name = symbol_name.split(" ", maxsplit=1)[0]
+ doc_item = self.doc_symbols.get(symbol_name)
+
+ return symbol_name, doc_item
+
+ async def get_symbol_markdown(self, doc_item: DocItem) -> str:
+ """
+ Get the Markdown from the symbol `doc_item` refers to.
+
+ First a redis lookup is attempted, if that fails the `item_fetcher`
+ is used to fetch the page and parse the HTML from it into Markdown.
+ """
+ markdown = await doc_cache.get(doc_item)
+
+ if markdown is None:
+ log.debug(f"Redis cache miss with {doc_item}.")
+ try:
+ markdown = await self.item_fetcher.get_markdown(doc_item)
+
+ except aiohttp.ClientError as e:
+ log.warning(f"A network error has occurred when requesting parsing of {doc_item}.", exc_info=e)
+ return "Unable to parse the requested symbol due to a network error."
+
+ except Exception:
+ log.exception(f"An unexpected error has occurred when requesting parsing of {doc_item}.")
+ return "Unable to parse the requested symbol due to an error."
+
+ if markdown is None:
+ return "Unable to parse the requested symbol."
+ return markdown
+
+ async def create_symbol_embed(self, symbol_name: str) -> Optional[discord.Embed]:
+ """
+ Attempt to scrape and fetch the data for the given `symbol_name`, and build an embed from its contents.
+
+ If the symbol is known, an Embed with documentation about it is returned.
+
+ First check the DocRedisCache before querying the cog's `BatchParser`.
+ """
+ log.trace(f"Building embed for symbol `{symbol_name}`")
+ if not self.refresh_event.is_set():
+ log.debug("Waiting for inventories to be refreshed before processing item.")
+ await self.refresh_event.wait()
+ # Ensure a refresh can't run in case of a context switch until the with block is exited
+ with self.symbol_get_event:
+ symbol_name, doc_item = self.get_symbol_item(symbol_name)
+ if doc_item is None:
+ log.debug("Symbol does not exist.")
+ return None
+
+ self.bot.stats.incr(f"doc_fetches.{doc_item.package}")
+
+ # Show all symbols with the same name that were renamed in the footer,
+ # with a max of 200 chars.
+ if symbol_name in self.renamed_symbols:
+ renamed_symbols = ", ".join(self.renamed_symbols[symbol_name])
+ footer_text = textwrap.shorten("Similar names: " + renamed_symbols, 200, placeholder=" ...")
+ else:
+ footer_text = ""
+
+ embed = discord.Embed(
+ title=discord.utils.escape_markdown(symbol_name),
+ url=f"{doc_item.url}#{doc_item.symbol_id}",
+ description=await self.get_symbol_markdown(doc_item)
+ )
+ embed.set_footer(text=footer_text)
+ return embed
+
+ @commands.group(name="docs", aliases=("doc", "d"), invoke_without_command=True)
+ async def docs_group(self, ctx: commands.Context, *, symbol_name: Optional[str]) -> None:
+ """Look up documentation for Python symbols."""
+ await self.get_command(ctx, symbol_name=symbol_name)
+
+ @docs_group.command(name="getdoc", aliases=("g",))
+ async def get_command(self, ctx: commands.Context, *, symbol_name: Optional[str]) -> None:
+ """
+ Return a documentation embed for a given symbol.
+
+ If no symbol is given, return a list of all available inventories.
+
+ Examples:
+ !docs
+ !docs aiohttp
+ !docs aiohttp.ClientSession
+ !docs getdoc aiohttp.ClientSession
+ """
+ if not symbol_name:
+ inventory_embed = discord.Embed(
+ title=f"All inventories (`{len(self.base_urls)}` total)",
+ colour=discord.Colour.blue()
+ )
+
+ lines = sorted(f"• [`{name}`]({url})" for name, url in self.base_urls.items())
+ if self.base_urls:
+ await LinePaginator.paginate(lines, ctx, inventory_embed, max_size=400, empty=False)
+
+ else:
+ inventory_embed.description = "Hmmm, seems like there's nothing here yet."
+ await ctx.send(embed=inventory_embed)
+
+ else:
+ symbol = symbol_name.strip("`")
+ async with ctx.typing():
+ doc_embed = await self.create_symbol_embed(symbol)
+
+ if doc_embed is None:
+ error_message = await send_denial(ctx, "No documentation found for the requested symbol.")
+ await wait_for_deletion(error_message, (ctx.author.id,), timeout=NOT_FOUND_DELETE_DELAY)
+ with suppress(discord.NotFound):
+ await ctx.message.delete()
+ with suppress(discord.NotFound):
+ await error_message.delete()
+ else:
+ msg = await ctx.send(embed=doc_embed)
+ await wait_for_deletion(msg, (ctx.author.id,))
+
+ @docs_group.command(name="setdoc", aliases=("s",))
+ @commands.has_any_role(*MODERATION_ROLES)
+ @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True)
+ async def set_command(
+ self,
+ ctx: commands.Context,
+ package_name: PackageName,
+ base_url: ValidURL,
+ inventory: Inventory,
+ ) -> None:
+ """
+ Adds a new documentation metadata object to the site's database.
+
+ The database will update the object, should an existing item with the specified `package_name` already exist.
+
+ Example:
+ !docs setdoc \
+ python \
+ https://docs.python.org/3/ \
+ https://docs.python.org/3/objects.inv
+ """
+ if not base_url.endswith("/"):
+ raise commands.BadArgument("The base url must end with a slash.")
+ inventory_url, inventory_dict = inventory
+ body = {
+ "package": package_name,
+ "base_url": base_url,
+ "inventory_url": inventory_url
+ }
+ await self.bot.api_client.post("bot/documentation-links", json=body)
+
+ log.info(
+ f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n"
+ + "\n".join(f"{key}: {value}" for key, value in body.items())
+ )
+
+ self.update_single(package_name, base_url, inventory_dict)
+ await ctx.send(f"Added the package `{package_name}` to the database and updated the inventories.")
+
+ @docs_group.command(name="deletedoc", aliases=("removedoc", "rm", "d"))
+ @commands.has_any_role(*MODERATION_ROLES)
+ @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True)
+ async def delete_command(self, ctx: commands.Context, package_name: PackageName) -> None:
+ """
+ Removes the specified package from the database.
+
+ Example:
+ !docs deletedoc aiohttp
+ """
+ await self.bot.api_client.delete(f"bot/documentation-links/{package_name}")
+
+ async with ctx.typing():
+ await self.refresh_inventories()
+ await doc_cache.delete(package_name)
+ await ctx.send(f"Successfully deleted `{package_name}` and refreshed the inventories.")
+
+ @docs_group.command(name="refreshdoc", aliases=("rfsh", "r"))
+ @commands.has_any_role(*MODERATION_ROLES)
+ @lock(NAMESPACE, COMMAND_LOCK_SINGLETON, raise_error=True)
+ async def refresh_command(self, ctx: commands.Context) -> None:
+ """Refresh inventories and show the difference."""
+ old_inventories = set(self.base_urls)
+ with ctx.typing():
+ await self.refresh_inventories()
+ new_inventories = set(self.base_urls)
+
+ if added := ", ".join(new_inventories - old_inventories):
+ added = "+ " + added
+
+ if removed := ", ".join(old_inventories - new_inventories):
+ removed = "- " + removed
+
+ embed = discord.Embed(
+ title="Inventories refreshed",
+ description=f"```diff\n{added}\n{removed}```" if added or removed else ""
+ )
+ await ctx.send(embed=embed)
+
+ @docs_group.command(name="cleardoccache", aliases=("deletedoccache",))
+ @commands.has_any_role(*MODERATION_ROLES)
+ async def clear_cache_command(
+ self,
+ ctx: commands.Context,
+ package_name: Union[PackageName, allowed_strings("*")] # noqa: F722
+ ) -> None:
+ """Clear the persistent redis cache for `package`."""
+ if await doc_cache.delete(package_name):
+ await ctx.send(f"Successfully cleared the cache for `{package_name}`.")
+ else:
+ await ctx.send("No keys matching the package found.")
+
+ def cog_unload(self) -> None:
+ """Clear scheduled inventories, queued symbols and cleanup task on cog unload."""
+ self.inventory_scheduler.cancel_all()
+ self.init_refresh_task.cancel()
+ asyncio.create_task(self.item_fetcher.clear(), name="DocCog.item_fetcher unload clear")
diff --git a/bot/exts/info/doc/_html.py b/bot/exts/info/doc/_html.py
new file mode 100644
index 000000000..94efd81b7
--- /dev/null
+++ b/bot/exts/info/doc/_html.py
@@ -0,0 +1,136 @@
+import logging
+import re
+from functools import partial
+from typing import Callable, Container, Iterable, List, Union
+
+from bs4 import BeautifulSoup
+from bs4.element import NavigableString, PageElement, SoupStrainer, Tag
+
+from . import MAX_SIGNATURE_AMOUNT
+
+log = logging.getLogger(__name__)
+
+_UNWANTED_SIGNATURE_SYMBOLS_RE = re.compile(r"\[source]|\\\\|¶")
+_SEARCH_END_TAG_ATTRS = (
+ "data",
+ "function",
+ "class",
+ "exception",
+ "seealso",
+ "section",
+ "rubric",
+ "sphinxsidebar",
+)
+
+
+class Strainer(SoupStrainer):
+ """Subclass of SoupStrainer to allow matching of both `Tag`s and `NavigableString`s."""
+
+ def __init__(self, *, include_strings: bool, **kwargs):
+ self.include_strings = include_strings
+ passed_text = kwargs.pop("text", None)
+ if passed_text is not None:
+ log.warning("`text` is not a supported kwarg in the custom strainer.")
+ super().__init__(**kwargs)
+
+ Markup = Union[PageElement, List["Markup"]]
+
+ def search(self, markup: Markup) -> Union[PageElement, str]:
+ """Extend default SoupStrainer behaviour to allow matching both `Tag`s` and `NavigableString`s."""
+ if isinstance(markup, str):
+ # Let everything through the text filter if we're including strings and tags.
+ if not self.name and not self.attrs and self.include_strings:
+ return markup
+ else:
+ return super().search(markup)
+
+
+def _find_elements_until_tag(
+ start_element: PageElement,
+ end_tag_filter: Union[Container[str], Callable[[Tag], bool]],
+ *,
+ func: Callable,
+ include_strings: bool = False,
+ limit: int = None,
+) -> List[Union[Tag, NavigableString]]:
+ """
+ Get all elements up to `limit` or until a tag matching `end_tag_filter` is found.
+
+ `end_tag_filter` can be either a container of string names to check against,
+ or a filtering callable that's applied to tags.
+
+ When `include_strings` is True, `NavigableString`s from the document will be included in the result along `Tag`s.
+
+ `func` takes in a BeautifulSoup unbound method for finding multiple elements, such as `BeautifulSoup.find_all`.
+ The method is then iterated over and all elements until the matching tag or the limit are added to the return list.
+ """
+ use_container_filter = not callable(end_tag_filter)
+ elements = []
+
+ for element in func(start_element, name=Strainer(include_strings=include_strings), limit=limit):
+ if isinstance(element, Tag):
+ if use_container_filter:
+ if element.name in end_tag_filter:
+ break
+ elif end_tag_filter(element):
+ break
+ elements.append(element)
+
+ return elements
+
+
+_find_next_children_until_tag = partial(_find_elements_until_tag, func=partial(BeautifulSoup.find_all, recursive=False))
+_find_recursive_children_until_tag = partial(_find_elements_until_tag, func=BeautifulSoup.find_all)
+_find_next_siblings_until_tag = partial(_find_elements_until_tag, func=BeautifulSoup.find_next_siblings)
+_find_previous_siblings_until_tag = partial(_find_elements_until_tag, func=BeautifulSoup.find_previous_siblings)
+
+
+def _class_filter_factory(class_names: Iterable[str]) -> Callable[[Tag], bool]:
+ """Create callable that returns True when the passed in tag's class is in `class_names` or when it's a table."""
+ def match_tag(tag: Tag) -> bool:
+ for attr in class_names:
+ if attr in tag.get("class", ()):
+ return True
+ return tag.name == "table"
+
+ return match_tag
+
+
+def get_general_description(start_element: Tag) -> List[Union[Tag, NavigableString]]:
+ """
+ Get page content to a table or a tag with its class in `SEARCH_END_TAG_ATTRS`.
+
+ A headerlink tag is attempted to be found to skip repeating the symbol information in the description.
+ If it's found it's used as the tag to start the search from instead of the `start_element`.
+ """
+ child_tags = _find_recursive_children_until_tag(start_element, _class_filter_factory(["section"]), limit=100)
+ header = next(filter(_class_filter_factory(["headerlink"]), child_tags), None)
+ start_tag = header.parent if header is not None else start_element
+ return _find_next_siblings_until_tag(start_tag, _class_filter_factory(_SEARCH_END_TAG_ATTRS), include_strings=True)
+
+
+def get_dd_description(symbol: PageElement) -> List[Union[Tag, NavigableString]]:
+ """Get the contents of the next dd tag, up to a dt or a dl tag."""
+ description_tag = symbol.find_next("dd")
+ return _find_next_children_until_tag(description_tag, ("dt", "dl"), include_strings=True)
+
+
+def get_signatures(start_signature: PageElement) -> List[str]:
+ """
+ Collect up to `_MAX_SIGNATURE_AMOUNT` signatures from dt tags around the `start_signature` dt tag.
+
+ First the signatures under the `start_signature` are included;
+ if less than 2 are found, tags above the start signature are added to the result if any are present.
+ """
+ signatures = []
+ for element in (
+ *reversed(_find_previous_siblings_until_tag(start_signature, ("dd",), limit=2)),
+ start_signature,
+ *_find_next_siblings_until_tag(start_signature, ("dd",), limit=2),
+ )[-MAX_SIGNATURE_AMOUNT:]:
+ signature = _UNWANTED_SIGNATURE_SYMBOLS_RE.sub("", element.text)
+
+ if signature:
+ signatures.append(signature)
+
+ return signatures
diff --git a/bot/exts/info/doc/_inventory_parser.py b/bot/exts/info/doc/_inventory_parser.py
new file mode 100644
index 000000000..80d5841a0
--- /dev/null
+++ b/bot/exts/info/doc/_inventory_parser.py
@@ -0,0 +1,126 @@
+import logging
+import re
+import zlib
+from collections import defaultdict
+from typing import AsyncIterator, DefaultDict, List, Optional, Tuple
+
+import aiohttp
+
+import bot
+
+log = logging.getLogger(__name__)
+
+FAILED_REQUEST_ATTEMPTS = 3
+_V2_LINE_RE = re.compile(r'(?x)(.+?)\s+(\S*:\S*)\s+(-?\d+)\s+?(\S*)\s+(.*)')
+
+InventoryDict = DefaultDict[str, List[Tuple[str, str]]]
+
+
+class ZlibStreamReader:
+ """Class used for decoding zlib data of a stream line by line."""
+
+ READ_CHUNK_SIZE = 16 * 1024
+
+ def __init__(self, stream: aiohttp.StreamReader) -> None:
+ self.stream = stream
+
+ async def _read_compressed_chunks(self) -> AsyncIterator[bytes]:
+ """Read zlib data in `READ_CHUNK_SIZE` sized chunks and decompress."""
+ decompressor = zlib.decompressobj()
+ async for chunk in self.stream.iter_chunked(self.READ_CHUNK_SIZE):
+ yield decompressor.decompress(chunk)
+
+ yield decompressor.flush()
+
+ async def __aiter__(self) -> AsyncIterator[str]:
+ """Yield lines of decompressed text."""
+ buf = b''
+ async for chunk in self._read_compressed_chunks():
+ buf += chunk
+ pos = buf.find(b'\n')
+ while pos != -1:
+ yield buf[:pos].decode()
+ buf = buf[pos + 1:]
+ pos = buf.find(b'\n')
+
+
+async def _load_v1(stream: aiohttp.StreamReader) -> InventoryDict:
+ invdata = defaultdict(list)
+
+ async for line in stream:
+ name, type_, location = line.decode().rstrip().split(maxsplit=2)
+ # version 1 did not add anchors to the location
+ if type_ == "mod":
+ type_ = "py:module"
+ location += "#module-" + name
+ else:
+ type_ = "py:" + type_
+ location += "#" + name
+ invdata[type_].append((name, location))
+ return invdata
+
+
+async def _load_v2(stream: aiohttp.StreamReader) -> InventoryDict:
+ invdata = defaultdict(list)
+
+ async for line in ZlibStreamReader(stream):
+ m = _V2_LINE_RE.match(line.rstrip())
+ name, type_, _prio, location, _dispname = m.groups() # ignore the parsed items we don't need
+ if location.endswith("$"):
+ location = location[:-1] + name
+
+ invdata[type_].append((name, location))
+ return invdata
+
+
+async def _fetch_inventory(url: str) -> InventoryDict:
+ """Fetch, parse and return an intersphinx inventory file from an url."""
+ timeout = aiohttp.ClientTimeout(sock_connect=5, sock_read=5)
+ async with bot.instance.http_session.get(url, timeout=timeout, raise_for_status=True) as response:
+ stream = response.content
+
+ inventory_header = (await stream.readline()).decode().rstrip()
+ inventory_version = int(inventory_header[-1:])
+ await stream.readline() # skip project name
+ await stream.readline() # skip project version
+
+ if inventory_version == 1:
+ return await _load_v1(stream)
+
+ elif inventory_version == 2:
+ if b"zlib" not in await stream.readline():
+ raise ValueError(f"Invalid inventory file at url {url}.")
+ return await _load_v2(stream)
+
+ raise ValueError(f"Invalid inventory file at url {url}.")
+
+
+async def fetch_inventory(url: str) -> Optional[InventoryDict]:
+ """
+ Get an inventory dict from `url`, retrying `FAILED_REQUEST_ATTEMPTS` times on errors.
+
+ `url` should point at a valid sphinx objects.inv inventory file, which will be parsed into the
+ inventory dict in the format of {"domain:role": [("symbol_name", "relative_url_to_symbol"), ...], ...}
+ """
+ for attempt in range(1, FAILED_REQUEST_ATTEMPTS+1):
+ try:
+ inventory = await _fetch_inventory(url)
+ except aiohttp.ClientConnectorError:
+ log.warning(
+ f"Failed to connect to inventory url at {url}; "
+ f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})."
+ )
+ except aiohttp.ClientError:
+ log.error(
+ f"Failed to get inventory from {url}; "
+ f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})."
+ )
+ except Exception:
+ log.exception(
+ f"An unexpected error has occurred during fetching of {url}; "
+ f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS})."
+ )
+ else:
+ return inventory
+
+ return None
diff --git a/bot/exts/info/doc/_markdown.py b/bot/exts/info/doc/_markdown.py
new file mode 100644
index 000000000..1b7d8232b
--- /dev/null
+++ b/bot/exts/info/doc/_markdown.py
@@ -0,0 +1,58 @@
+from urllib.parse import urljoin
+
+from bs4.element import PageElement
+from markdownify import MarkdownConverter
+
+
+class DocMarkdownConverter(MarkdownConverter):
+    """Subclass markdownify's MarkdownConverter to provide custom conversion methods."""
+
+    def __init__(self, *, page_url: str, **options):
+        # `page_url` is the absolute URL of the documentation page being converted;
+        # it is used to resolve relative links in `convert_a`.
+        super().__init__(**options)
+        self.page_url = page_url
+
+    def convert_li(self, el: PageElement, text: str, convert_as_inline: bool) -> str:
+        """Fix markdownify's erroneous indexing in ol tags."""
+        parent = el.parent
+        if parent is not None and parent.name == "ol":
+            # Number the item by its actual position among the ol's li children.
+            li_tags = parent.find_all("li")
+            bullet = f"{li_tags.index(el)+1}."
+        else:
+            # Unordered list: pick the bullet character based on how deeply
+            # this li is nested in ul ancestors.
+            depth = -1
+            while el:
+                if el.name == "ul":
+                    depth += 1
+                el = el.parent
+            bullets = self.options["bullets"]
+            bullet = bullets[depth % len(bullets)]
+        return f"{bullet} {text}\n"
+
+    def convert_hn(self, _n: int, el: PageElement, text: str, convert_as_inline: bool) -> str:
+        """Convert h tags to bold text with ** instead of adding #."""
+        if convert_as_inline:
+            return text
+        return f"**{text}**\n\n"
+
+    def convert_code(self, el: PageElement, text: str, convert_as_inline: bool) -> str:
+        """Undo `markdownify`s underscore escaping."""
+        # Note: this strips every backslash, not only those escaping underscores.
+        return f"`{text}`".replace("\\", "")
+
+    def convert_pre(self, el: PageElement, text: str, convert_as_inline: bool) -> str:
+        """Wrap any codeblocks in `py` for syntax highlighting."""
+        code = "".join(el.strings)
+        return f"```py\n{code}```"
+
+    def convert_a(self, el: PageElement, text: str, convert_as_inline: bool) -> str:
+        """Resolve relative URLs to `self.page_url`."""
+        el["href"] = urljoin(self.page_url, el["href"])
+        return super().convert_a(el, text, convert_as_inline)
+
+    def convert_p(self, el: PageElement, text: str, convert_as_inline: bool) -> str:
+        """Include only one newline instead of two when the parent is a li tag."""
+        if convert_as_inline:
+            return text
+
+        parent = el.parent
+        if parent is not None and parent.name == "li":
+            return f"{text}\n"
+        return super().convert_p(el, text, convert_as_inline)
diff --git a/bot/exts/info/doc/_parsing.py b/bot/exts/info/doc/_parsing.py
new file mode 100644
index 000000000..bf840b96f
--- /dev/null
+++ b/bot/exts/info/doc/_parsing.py
@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+import logging
+import re
+import string
+import textwrap
+from collections import namedtuple
+from typing import Collection, Iterable, Iterator, List, Optional, TYPE_CHECKING, Union
+
+from bs4 import BeautifulSoup
+from bs4.element import NavigableString, Tag
+
+from bot.utils.helpers import find_nth_occurrence
+from . import MAX_SIGNATURE_AMOUNT
+from ._html import get_dd_description, get_general_description, get_signatures
+from ._markdown import DocMarkdownConverter
+if TYPE_CHECKING:
+ from ._cog import DocItem
+
+log = logging.getLogger(__name__)
+
+_WHITESPACE_AFTER_NEWLINES_RE = re.compile(r"(?<=\n\n)(\s+)")
+_PARAMETERS_RE = re.compile(r"\((.+)\)")
+
+_NO_SIGNATURE_GROUPS = {
+ "attribute",
+ "envvar",
+ "setting",
+ "tempaltefilter",
+ "templatetag",
+ "term",
+}
+_EMBED_CODE_BLOCK_LINE_LENGTH = 61
+# _MAX_SIGNATURE_AMOUNT code block wrapped lines with py syntax highlight
+_MAX_SIGNATURES_LENGTH = (_EMBED_CODE_BLOCK_LINE_LENGTH + 8) * MAX_SIGNATURE_AMOUNT
+# Maximum embed description length - signatures on top
+_MAX_DESCRIPTION_LENGTH = 2048 - _MAX_SIGNATURES_LENGTH
+_TRUNCATE_STRIP_CHARACTERS = "!?:;." + string.whitespace
+
+BracketPair = namedtuple("BracketPair", ["opening_bracket", "closing_bracket"])
+_BRACKET_PAIRS = {
+ "{": BracketPair("{", "}"),
+ "(": BracketPair("(", ")"),
+ "[": BracketPair("[", "]"),
+ "<": BracketPair("<", ">"),
+}
+
+
+def _split_parameters(parameters_string: str) -> Iterator[str]:
+    """
+    Split parameters of a signature into individual parameter strings on commas.
+
+    Commas nested inside brackets or quoted strings are not treated as separators.
+    Long string literals are not accounted for.
+    """
+    last_split = 0
+    depth = 0
+    current_search: Optional[BracketPair] = None
+
+    # A single shared iterator lets the inner string-skipping loop advance the
+    # outer loop's position past an entire quoted literal.
+    enumerated_string = enumerate(parameters_string)
+    for index, character in enumerated_string:
+        if character in {"'", '"'}:
+            # Skip everything inside of strings, regardless of the depth.
+            quote_character = character  # The closing quote must equal the opening quote.
+            preceding_backslashes = 0
+            for _, character in enumerated_string:
+                # If an odd number of backslashes precedes the quote, it was escaped.
+                if character == quote_character and not preceding_backslashes % 2:
+                    break
+                if character == "\\":
+                    preceding_backslashes += 1
+                else:
+                    preceding_backslashes = 0
+
+        elif current_search is None:
+            # At the top level: either open a bracketed region or split on a comma.
+            if (current_search := _BRACKET_PAIRS.get(character)) is not None:
+                depth = 1
+            elif character == ",":
+                yield parameters_string[last_split:index]
+                last_split = index + 1
+
+        else:
+            # Inside brackets: only nesting of the bracket type opened first is tracked.
+            if character == current_search.opening_bracket:
+                depth += 1
+
+            elif character == current_search.closing_bracket:
+                depth -= 1
+                if depth == 0:
+                    current_search = None
+
+    # The final parameter, after the last top-level comma.
+    yield parameters_string[last_split:]
+
+
+def _truncate_signatures(signatures: Collection[str]) -> Union[List[str], Collection[str]]:
+    """
+    Truncate passed signatures to not exceed `_MAX_SIGNATURES_LENGTH`.
+
+    If the signatures need to be truncated, parameters are collapsed until they fit within the limit.
+    Individual signatures can consist of max 1, 2, ..., `_MAX_SIGNATURE_AMOUNT` lines of text,
+    inversely proportional to the amount of signatures.
+    A maximum of `_MAX_SIGNATURE_AMOUNT` signatures is assumed to be passed.
+    """
+    if sum(len(signature) for signature in signatures) <= _MAX_SIGNATURES_LENGTH:
+        # Total length of signatures is under the length limit; no truncation needed.
+        return signatures
+
+    # Fewer signatures means each one gets a larger individual budget.
+    max_signature_length = _EMBED_CODE_BLOCK_LINE_LENGTH * (MAX_SIGNATURE_AMOUNT + 1 - len(signatures))
+    formatted_signatures = []
+    for signature in signatures:
+        signature = signature.strip()
+        if len(signature) > max_signature_length:
+            if (parameters_match := _PARAMETERS_RE.search(signature)) is None:
+                # The signature has no parameters or the regex failed; perform a simple truncation of the text.
+                formatted_signatures.append(textwrap.shorten(signature, max_signature_length, placeholder="..."))
+                continue
+
+            # Drop trailing parameters until the signature fits, replacing them with " ...".
+            truncated_signature = []
+            parameters_string = parameters_match[1]
+            running_length = len(signature) - len(parameters_string)
+            for parameter in _split_parameters(parameters_string):
+                # Check if including this parameter would still be within the maximum length.
+                if (len(parameter) + running_length) <= max_signature_length - 5:  # account for comma and placeholder
+                    truncated_signature.append(parameter)
+                    running_length += len(parameter) + 1
+                else:
+                    # There's no more room for this parameter. Truncate the parameter list and put it in the signature.
+                    truncated_signature.append(" ...")
+                    formatted_signatures.append(signature.replace(parameters_string, ",".join(truncated_signature)))
+                    break
+        else:
+            # The current signature is under the length limit; no truncation needed.
+            formatted_signatures.append(signature)
+
+    return formatted_signatures
+
+
+def _get_truncated_description(
+    elements: Iterable[Union[Tag, NavigableString]],
+    markdown_converter: DocMarkdownConverter,
+    max_length: int,
+    max_lines: int,
+) -> str:
+    """
+    Truncate the Markdown from `elements` to be at most `max_length` characters when rendered or `max_lines` newlines.
+
+    `max_length` limits the length of the rendered characters in the string,
+    with the real string length limited to `_MAX_DESCRIPTION_LENGTH` to accommodate discord length limits.
+    """
+    result = ""
+    markdown_element_ends = []  # Stores indices into `result` which point to the end boundary of each Markdown element.
+    rendered_length = 0  # Length of the visible text, excluding Markdown markup characters.
+
+    tag_end_index = 0
+    for element in elements:
+        is_tag = isinstance(element, Tag)
+        element_length = len(element.text) if is_tag else len(element)
+
+        if rendered_length + element_length < max_length:
+            if is_tag:
+                element_markdown = markdown_converter.process_tag(element, convert_as_inline=False)
+            else:
+                element_markdown = markdown_converter.process_text(element)
+
+            rendered_length += element_length
+            tag_end_index += len(element_markdown)
+
+            # Whitespace-only elements are not valid truncation boundaries.
+            if not element_markdown.isspace():
+                markdown_element_ends.append(tag_end_index)
+            result += element_markdown
+        else:
+            break
+
+    if not markdown_element_ends:
+        return ""
+
+    # Determine the "hard" truncation index. Account for the ellipsis placeholder for the max length.
+    newline_truncate_index = find_nth_occurrence(result, "\n", max_lines)
+    if newline_truncate_index is not None and newline_truncate_index < _MAX_DESCRIPTION_LENGTH - 3:
+        # Truncate based on maximum lines if there are more than the maximum number of lines.
+        truncate_index = newline_truncate_index
+    else:
+        # There are less than the maximum number of lines; truncate based on the max char length.
+        truncate_index = _MAX_DESCRIPTION_LENGTH - 3
+
+    # Nothing needs to be truncated if the last element ends before the truncation index.
+    if truncate_index >= markdown_element_ends[-1]:
+        return result
+
+    # Determine the actual truncation index.
+    possible_truncation_indices = [cut for cut in markdown_element_ends if cut < truncate_index]
+    if not possible_truncation_indices:
+        # In case there is no Markdown element ending before the truncation index, try to find a good cutoff point.
+        force_truncated = result[:truncate_index]
+        # If there is an incomplete codeblock, cut it out.
+        if force_truncated.count("```") % 2:
+            force_truncated = force_truncated[:force_truncated.rfind("```")]
+        # Search for substrings to truncate at, with decreasing desirability.
+        for string_ in ("\n\n", "\n", ". ", ", ", ",", " "):
+            cutoff = force_truncated.rfind(string_)
+
+            if cutoff != -1:
+                truncated_result = force_truncated[:cutoff]
+                break
+        else:
+            # No better cutoff found; keep the hard truncation.
+            truncated_result = force_truncated
+
+    else:
+        # Truncate at the last Markdown element that comes before the truncation index.
+        markdown_truncate_index = possible_truncation_indices[-1]
+        truncated_result = result[:markdown_truncate_index]
+
+    return truncated_result.strip(_TRUNCATE_STRIP_CHARACTERS) + "..."
+
+
+def _create_markdown(signatures: Optional[List[str]], description: Iterable[Tag], url: str) -> str:
+    """
+    Create a Markdown string with the signatures at the top, and the converted html description below them.
+
+    The signatures are wrapped in python codeblocks, separated from the description by a newline.
+    The result Markdown string is max 750 rendered characters for the description with signatures at the start.
+    `url` is used to resolve relative links in the description.
+    """
+    description = _get_truncated_description(
+        description,
+        markdown_converter=DocMarkdownConverter(bullets="•", page_url=url),
+        max_length=750,
+        max_lines=13
+    )
+    # Collapse indentation left over after newlines to keep the embed compact.
+    description = _WHITESPACE_AFTER_NEWLINES_RE.sub("", description)
+    if signatures is not None:
+        signature = "".join(f"```py\n{signature}```" for signature in _truncate_signatures(signatures))
+        return f"{signature}\n{description}"
+    else:
+        return description
+
+
+def get_symbol_markdown(soup: BeautifulSoup, symbol_data: DocItem) -> Optional[str]:
+    """
+    Return parsed Markdown of the passed item using the passed in soup, truncated to fit within a discord message.
+
+    The method of parsing and what information gets included depends on the symbol's group.
+    Return None if the symbol's id cannot be found in `soup`.
+    """
+    symbol_heading = soup.find(id=symbol_data.symbol_id)
+    if symbol_heading is None:
+        return None
+    signature = None
+    # Modules, doc pages and labels don't point to description list tags but to tags like divs,
+    # no special parsing can be done so we only try to include what's under them.
+    if symbol_heading.name != "dt":
+        description = get_general_description(symbol_heading)
+
+    elif symbol_data.group in _NO_SIGNATURE_GROUPS:
+        # Groups in `_NO_SIGNATURE_GROUPS` have no signature worth showing.
+        description = get_dd_description(symbol_heading)
+
+    else:
+        signature = get_signatures(symbol_heading)
+        description = get_dd_description(symbol_heading)
+    # "¶" characters are the headerlink anchors copied over from the docs HTML.
+    return _create_markdown(signature, description, symbol_data.url).replace("¶", "").strip()
diff --git a/bot/exts/info/doc/_redis_cache.py b/bot/exts/info/doc/_redis_cache.py
new file mode 100644
index 000000000..ad764816f
--- /dev/null
+++ b/bot/exts/info/doc/_redis_cache.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import datetime
+from typing import Optional, TYPE_CHECKING
+
+from async_rediscache.types.base import RedisObject, namespace_lock
+if TYPE_CHECKING:
+ from ._cog import DocItem
+
+WEEK_SECONDS = datetime.timedelta(weeks=1).total_seconds()
+
+
+class DocRedisCache(RedisObject):
+    """Interface for redis functionality needed by the Doc cog."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # In-process memo of redis keys already checked for an expire;
+        # avoids an EXISTS round trip per item from the same page.
+        # NOTE(review): entries are never evicted, so a key that expired in Redis
+        # will be recreated by `set` without a new TTL while this process lives — confirm.
+        self._set_expires = set()
+
+    @namespace_lock
+    async def set(self, item: DocItem, value: str) -> None:
+        """
+        Set the Markdown `value` for the symbol `item`.
+
+        All keys from a single page are stored together, expiring a week after the first set.
+        """
+        url_key = remove_suffix(item.relative_url_path, ".html")
+        redis_key = f"{self.namespace}:{item.package}:{url_key}"
+        needs_expire = False
+
+        with await self._get_pool_connection() as connection:
+            if redis_key not in self._set_expires:
+                # An expire is only set if the key didn't exist before.
+                # If this is the first time setting values for this key check if it exists and add it to
+                # `_set_expires` to prevent redundant checks for subsequent uses with items from the same page.
+                self._set_expires.add(redis_key)
+                needs_expire = not await connection.exists(redis_key)
+
+            # One hash per page, keyed by the symbol's id within the page.
+            await connection.hset(redis_key, item.symbol_id, value)
+            if needs_expire:
+                await connection.expire(redis_key, WEEK_SECONDS)
+
+    @namespace_lock
+    async def get(self, item: DocItem) -> Optional[str]:
+        """Return the Markdown content of the symbol `item` if it exists."""
+        url_key = remove_suffix(item.relative_url_path, ".html")
+
+        with await self._get_pool_connection() as connection:
+            return await connection.hget(f"{self.namespace}:{item.package}:{url_key}", item.symbol_id, encoding="utf8")
+
+    @namespace_lock
+    async def delete(self, package: str) -> bool:
+        """Remove all values for `package`; return True if at least one key was deleted, False otherwise."""
+        with await self._get_pool_connection() as connection:
+            # Scan instead of KEYS to avoid blocking redis on large namespaces.
+            package_keys = [
+                package_key async for package_key in connection.iscan(match=f"{self.namespace}:{package}:*")
+            ]
+            if package_keys:
+                await connection.delete(*package_keys)
+                return True
+            return False
+
+
+def remove_suffix(string: str, suffix: str) -> str:
+    """Remove `suffix` from end of `string`; return `string` unchanged if it lacks the suffix."""
+    # TODO replace usages with str.removesuffix on 3.9
+    if string.endswith(suffix):
+        return string[:-len(suffix)]
+    else:
+        return string
diff --git a/bot/exts/info/information.py b/bot/exts/info/information.py
index 5e2c4b417..834fee1b4 100644
--- a/bot/exts/info/information.py
+++ b/bot/exts/info/information.py
@@ -230,6 +230,11 @@ class Information(Cog):
if on_server and user.nick:
name = f"{user.nick} ({name})"
+ if user.public_flags.verified_bot:
+ name += f" {constants.Emojis.verified_bot}"
+ elif user.bot:
+ name += f" {constants.Emojis.bot}"
+
badges = []
for badge, is_set in user.public_flags:
diff --git a/bot/exts/info/source.py b/bot/exts/info/source.py
index 49e74f204..ef07c77a1 100644
--- a/bot/exts/info/source.py
+++ b/bot/exts/info/source.py
@@ -14,9 +14,10 @@ SourceType = Union[commands.HelpCommand, commands.Command, commands.Cog, str, co
class SourceConverter(commands.Converter):
"""Convert an argument into a help command, tag, command, or cog."""
- async def convert(self, ctx: commands.Context, argument: str) -> SourceType:
+ @staticmethod
+ async def convert(ctx: commands.Context, argument: str) -> SourceType:
"""Convert argument into source object."""
- if argument.lower().startswith("help"):
+ if argument.lower() == "help":
return ctx.bot.help_command
cog = ctx.bot.get_cog(argument)
@@ -68,7 +69,8 @@ class BotSource(commands.Cog):
Raise BadArgument if `source_item` is a dynamically-created object (e.g. via internal eval).
"""
if isinstance(source_item, commands.Command):
- src = source_item.callback.__code__
+ source_item = inspect.unwrap(source_item.callback)
+ src = source_item.__code__
filename = src.co_filename
elif isinstance(source_item, str):
tags_cog = self.bot.get_cog("Tags")
diff --git a/bot/exts/moderation/infraction/infractions.py b/bot/exts/moderation/infraction/infractions.py
index d89e80acc..38d1ffc0e 100644
--- a/bot/exts/moderation/infraction/infractions.py
+++ b/bot/exts/moderation/infraction/infractions.py
@@ -54,8 +54,12 @@ class Infractions(InfractionScheduler, commands.Cog):
# region: Permanent infractions
@command()
- async def warn(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None:
+ async def warn(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:
"""Warn a user for the given reason."""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
infraction = await _utils.post_infraction(ctx, user, "warning", reason, active=False)
if infraction is None:
return
@@ -63,8 +67,12 @@ class Infractions(InfractionScheduler, commands.Cog):
await self.apply_infraction(ctx, infraction, user)
@command()
- async def kick(self, ctx: Context, user: Member, *, reason: t.Optional[str] = None) -> None:
+ async def kick(self, ctx: Context, user: FetchedMember, *, reason: t.Optional[str] = None) -> None:
"""Kick a user for the given reason."""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
await self.apply_kick(ctx, user, reason)
@command()
@@ -100,7 +108,7 @@ class Infractions(InfractionScheduler, commands.Cog):
@command(aliases=["mute"])
async def tempmute(
self, ctx: Context,
- user: Member,
+ user: FetchedMember,
duration: t.Optional[Expiry] = None,
*,
reason: t.Optional[str] = None
@@ -122,6 +130,10 @@ class Infractions(InfractionScheduler, commands.Cog):
If no duration is given, a one hour duration is used by default.
"""
+ if not isinstance(user, Member):
+ await ctx.send(":x: The user doesn't appear to be on the server.")
+ return
+
if duration is None:
duration = await Duration().convert(ctx, "1h")
await self.apply_mute(ctx, user, reason, expires_at=duration)
diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py
index 704dddf9c..07e79b9fe 100644
--- a/bot/exts/moderation/infraction/superstarify.py
+++ b/bot/exts/moderation/infraction/superstarify.py
@@ -11,7 +11,7 @@ from discord.utils import escape_markdown
from bot import constants
from bot.bot import Bot
-from bot.converters import Expiry
+from bot.converters import Duration, Expiry
from bot.exts.moderation.infraction import _utils
from bot.exts.moderation.infraction._scheduler import InfractionScheduler
from bot.utils.messages import format_user
@@ -19,6 +19,7 @@ from bot.utils.time import format_infraction
log = logging.getLogger(__name__)
NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy"
+SUPERSTARIFY_DEFAULT_DURATION = "1h"
with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file:
STAR_NAMES = json.load(stars_file)
@@ -109,7 +110,7 @@ class Superstarify(InfractionScheduler, Cog):
self,
ctx: Context,
member: Member,
- duration: Expiry,
+ duration: t.Optional[Expiry],
*,
reason: str = '',
) -> None:
@@ -134,6 +135,9 @@ class Superstarify(InfractionScheduler, Cog):
if await _utils.get_active_infraction(ctx, member, "superstar"):
return
+ # Set to default duration if none was provided.
+ duration = duration or await Duration().convert(ctx, SUPERSTARIFY_DEFAULT_DURATION)
+
# Post the infraction to the API
old_nick = member.display_name
infraction_reason = f'Old nickname: {old_nick}. {reason}'
diff --git a/bot/exts/moderation/modlog.py b/bot/exts/moderation/modlog.py
index 2dae9d268..e92f76c9a 100644
--- a/bot/exts/moderation/modlog.py
+++ b/bot/exts/moderation/modlog.py
@@ -14,7 +14,7 @@ from discord.abc import GuildChannel
from discord.ext.commands import Cog, Context
from bot.bot import Bot
-from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs
+from bot.constants import Categories, Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, Roles, URLs
from bot.utils.messages import format_user
from bot.utils.time import humanize_delta
@@ -115,9 +115,9 @@ class ModLog(Cog, name="ModLog"):
if ping_everyone:
if content:
- content = f"@everyone\n{content}"
+ content = f"<@&{Roles.moderators}>\n{content}"
else:
- content = "@everyone"
+ content = f"<@&{Roles.moderators}>"
# Truncate content to 2000 characters and append an ellipsis.
if content and len(content) > 2000:
@@ -127,8 +127,7 @@ class ModLog(Cog, name="ModLog"):
log_message = await channel.send(
content=content,
embed=embed,
- files=files,
- allowed_mentions=discord.AllowedMentions(everyone=True)
+ files=files
)
if additional_embeds:
diff --git a/bot/exts/moderation/modpings.py b/bot/exts/moderation/modpings.py
new file mode 100644
index 000000000..1ad5005de
--- /dev/null
+++ b/bot/exts/moderation/modpings.py
@@ -0,0 +1,138 @@
+import datetime
+import logging
+
+from async_rediscache import RedisCache
+from dateutil.parser import isoparse
+from discord import Embed, Member
+from discord.ext.commands import Cog, Context, group, has_any_role
+
+from bot.bot import Bot
+from bot.constants import Colours, Emojis, Guild, Icons, MODERATION_ROLES, Roles
+from bot.converters import Expiry
+from bot.utils.scheduling import Scheduler
+
+log = logging.getLogger(__name__)
+
+
+class ModPings(Cog):
+ """Commands for a moderator to turn moderator pings on and off."""
+
+ # RedisCache[discord.Member.id, 'Naïve ISO 8601 string']
+ # The cache's keys are mods who have pings off.
+ # The cache's values are the times when the role should be re-applied to them, stored in ISO format.
+ pings_off_mods = RedisCache()
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+ self._role_scheduler = Scheduler(self.__class__.__name__)
+
+ self.guild = None
+ self.moderators_role = None
+
+ self.reschedule_task = self.bot.loop.create_task(self.reschedule_roles(), name="mod-pings-reschedule")
+
+ async def reschedule_roles(self) -> None:
+ """Reschedule moderators role re-apply times."""
+ await self.bot.wait_until_guild_available()
+ self.guild = self.bot.get_guild(Guild.id)
+ self.moderators_role = self.guild.get_role(Roles.moderators)
+
+ mod_team = self.guild.get_role(Roles.mod_team)
+ pings_on = self.moderators_role.members
+ pings_off = await self.pings_off_mods.to_dict()
+
+ log.trace("Applying the moderators role to the mod team where necessary.")
+ for mod in mod_team.members:
+ if mod in pings_on: # Make sure that on-duty mods aren't in the cache.
+ if mod in pings_off:
+ await self.pings_off_mods.delete(mod.id)
+ continue
+
+ # Keep the role off only for those in the cache.
+ if mod.id not in pings_off:
+ await self.reapply_role(mod)
+ else:
+ expiry = isoparse(pings_off[mod.id]).replace(tzinfo=None)
+ self._role_scheduler.schedule_at(expiry, mod.id, self.reapply_role(mod))
+
+ async def reapply_role(self, mod: Member) -> None:
+ """Reapply the moderator's role to the given moderator."""
+ log.trace(f"Re-applying role to mod with ID {mod.id}.")
+ await mod.add_roles(self.moderators_role, reason="Pings off period expired.")
+
+ @group(name='modpings', aliases=('modping',), invoke_without_command=True)
+ @has_any_role(*MODERATION_ROLES)
+ async def modpings_group(self, ctx: Context) -> None:
+ """Allow the removal and re-addition of the pingable moderators role."""
+ await ctx.send_help(ctx.command)
+
+ @modpings_group.command(name='off')
+ @has_any_role(*MODERATION_ROLES)
+ async def off_command(self, ctx: Context, duration: Expiry) -> None:
+ """
+ Temporarily removes the pingable moderators role for a set amount of time.
+
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
+
+ The duration cannot be longer than 30 days.
+ """
+ duration: datetime.datetime
+ delta = duration - datetime.datetime.utcnow()
+ if delta > datetime.timedelta(days=30):
+ await ctx.send(":x: Cannot remove the role for longer than 30 days.")
+ return
+
+ mod = ctx.author
+
+ until_date = duration.replace(microsecond=0).isoformat() # Looks noisy with microseconds.
+ await mod.remove_roles(self.moderators_role, reason=f"Turned pings off until {until_date}.")
+
+ await self.pings_off_mods.set(mod.id, duration.isoformat())
+
+ # Allow rescheduling the task without cancelling it separately via the `on` command.
+ if mod.id in self._role_scheduler:
+ self._role_scheduler.cancel(mod.id)
+ self._role_scheduler.schedule_at(duration, mod.id, self.reapply_role(mod))
+
+ embed = Embed(timestamp=duration, colour=Colours.bright_green)
+ embed.set_footer(text="Moderators role has been removed until", icon_url=Icons.green_checkmark)
+ await ctx.send(embed=embed)
+
+ @modpings_group.command(name='on')
+ @has_any_role(*MODERATION_ROLES)
+ async def on_command(self, ctx: Context) -> None:
+ """Re-apply the pingable moderators role."""
+ mod = ctx.author
+ if mod in self.moderators_role.members:
+ await ctx.send(":question: You already have the role.")
+ return
+
+ await mod.add_roles(self.moderators_role, reason="Pings off period canceled.")
+
+ await self.pings_off_mods.delete(mod.id)
+
+ # We assume the task exists. Lack of it may indicate a bug.
+ self._role_scheduler.cancel(mod.id)
+
+ await ctx.send(f"{Emojis.check_mark} Moderators role has been re-applied.")
+
+ def cog_unload(self) -> None:
+ """Cancel role tasks when the cog unloads."""
+ log.trace("Cog unload: canceling role tasks.")
+ self.reschedule_task.cancel()
+ self._role_scheduler.cancel_all()
+
+
+def setup(bot: Bot) -> None:
+    """Load the ModPings cog."""
+    # Extension entry point invoked by discord.py when the extension is loaded.
+    bot.add_cog(ModPings(bot))
diff --git a/bot/exts/moderation/stream.py b/bot/exts/moderation/stream.py
index 12e195172..fd856a7f4 100644
--- a/bot/exts/moderation/stream.py
+++ b/bot/exts/moderation/stream.py
@@ -1,5 +1,6 @@
import logging
from datetime import timedelta, timezone
+from operator import itemgetter
import arrow
import discord
@@ -8,8 +9,9 @@ from async_rediscache import RedisCache
from discord.ext import commands
from bot.bot import Bot
-from bot.constants import Colours, Emojis, Guild, Roles, STAFF_ROLES, VideoPermission
+from bot.constants import Colours, Emojis, Guild, MODERATION_ROLES, Roles, STAFF_ROLES, VideoPermission
from bot.converters import Expiry
+from bot.pagination import LinePaginator
from bot.utils.scheduling import Scheduler
from bot.utils.time import format_infraction_with_duration
@@ -68,8 +70,30 @@ class Stream(commands.Cog):
self._revoke_streaming_permission(member)
)
+ async def _suspend_stream(self, ctx: commands.Context, member: discord.Member) -> None:
+ """Suspend a member's stream."""
+ await self.bot.wait_until_guild_available()
+ voice_state = member.voice
+
+ if not voice_state:
+ return
+
+ # If the user is streaming.
+ if voice_state.self_stream:
+ # End user's stream by moving them to AFK voice channel and back.
+ original_vc = voice_state.channel
+ await member.move_to(ctx.guild.afk_channel)
+ await member.move_to(original_vc)
+
+ # Notify.
+ await ctx.send(f"{member.mention}'s stream has been suspended!")
+ log.debug(f"Successfully suspended stream from {member} ({member.id}).")
+ return
+
+ log.debug(f"No stream found to suspend from {member} ({member.id}).")
+
@commands.command(aliases=("streaming",))
- @commands.has_any_role(*STAFF_ROLES)
+ @commands.has_any_role(*MODERATION_ROLES)
async def stream(self, ctx: commands.Context, member: discord.Member, duration: Expiry = None) -> None:
"""
Temporarily grant streaming permissions to a member for a given duration.
@@ -126,7 +150,7 @@ class Stream(commands.Cog):
log.debug(f"Successfully gave {member} ({member.id}) permission to stream until {revoke_time}.")
@commands.command(aliases=("pstream",))
- @commands.has_any_role(*STAFF_ROLES)
+ @commands.has_any_role(*MODERATION_ROLES)
async def permanentstream(self, ctx: commands.Context, member: discord.Member) -> None:
"""Permanently grants the given member the permission to stream."""
log.trace(f"Attempting to give permanent streaming permission to {member} ({member.id}).")
@@ -153,7 +177,7 @@ class Stream(commands.Cog):
log.debug(f"Successfully gave {member} ({member.id}) permanent streaming permission.")
@commands.command(aliases=("unstream", "rstream"))
- @commands.has_any_role(*STAFF_ROLES)
+ @commands.has_any_role(*MODERATION_ROLES)
async def revokestream(self, ctx: commands.Context, member: discord.Member) -> None:
"""Revoke the permission to stream from the given member."""
log.trace(f"Attempting to remove streaming permission from {member} ({member.id}).")
@@ -168,10 +192,52 @@ class Stream(commands.Cog):
await ctx.send(f"{Emojis.check_mark} Revoked the permission to stream from {member.mention}.")
log.debug(f"Successfully revoked streaming permission from {member} ({member.id}).")
- return
- await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!")
- log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!")
+ else:
+ await ctx.send(f"{Emojis.cross_mark} This member doesn't have video permissions to remove!")
+ log.debug(f"{member} ({member.id}) didn't have the streaming permission to remove!")
+
+ await self._suspend_stream(ctx, member)
+
+ @commands.command(aliases=('lstream',))
+ @commands.has_any_role(*MODERATION_ROLES)
+ async def liststream(self, ctx: commands.Context) -> None:
+ """Lists all non-staff users who have permission to stream."""
+ non_staff_members_with_stream = [
+ member
+ for member in ctx.guild.get_role(Roles.video).members
+ if not any(role.id in STAFF_ROLES for role in member.roles)
+ ]
+
+ # List of tuples (UtcPosixTimestamp, str)
+ # So that the list can be sorted on the UtcPosixTimestamp before the message is passed to the paginator.
+ streamer_info = []
+ for member in non_staff_members_with_stream:
+ if revoke_time := await self.task_cache.get(member.id):
+ # Member only has temporary streaming perms
+ revoke_delta = Arrow.utcfromtimestamp(revoke_time).humanize()
+ message = f"{member.mention} will have stream permissions revoked {revoke_delta}."
+ else:
+ message = f"{member.mention} has permanent streaming permissions."
+
+ # If revoke_time is None use max timestamp to force sort to put them at the end
+ streamer_info.append(
+ (revoke_time or Arrow.max.timestamp(), message)
+ )
+
+ if streamer_info:
+ # Sort based on duration left of streaming perms
+ streamer_info.sort(key=itemgetter(0))
+
+ # Only output the message in the pagination
+ lines = [line[1] for line in streamer_info]
+ embed = discord.Embed(
+ title=f"Members with streaming permission (`{len(lines)}` total)",
+ colour=Colours.soft_green
+ )
+ await LinePaginator.paginate(lines, ctx, embed, max_size=400, empty=False)
+ else:
+ await ctx.send("No members with stream permissions found.")
def setup(bot: Bot) -> None:
diff --git a/bot/exts/utils/clean.py b/bot/exts/utils/clean.py
index 8acaf9131..cb662e852 100644
--- a/bot/exts/utils/clean.py
+++ b/bot/exts/utils/clean.py
@@ -3,7 +3,7 @@ import random
import re
from typing import Iterable, Optional
-from discord import Colour, Embed, Message, TextChannel, User
+from discord import Colour, Embed, Message, TextChannel, User, errors
from discord.ext import commands
from discord.ext.commands import Cog, Context, group, has_any_role
@@ -115,7 +115,11 @@ class Clean(Cog):
# Delete the invocation first
self.mod_log.ignore(Event.message_delete, ctx.message.id)
- await ctx.message.delete()
+ try:
+ await ctx.message.delete()
+ except errors.NotFound:
+ # Invocation message has already been deleted
+ log.info("Tried to delete invocation message, but it was already deleted.")
messages = []
message_ids = []
diff --git a/bot/exts/utils/reminders.py b/bot/exts/utils/reminders.py
index 3113a1149..6c21920a1 100644
--- a/bot/exts/utils/reminders.py
+++ b/bot/exts/utils/reminders.py
@@ -90,15 +90,18 @@ class Reminders(Cog):
delivery_dt: t.Optional[datetime],
) -> None:
"""Send an embed confirming the reminder change was made successfully."""
- embed = discord.Embed()
- embed.colour = discord.Colour.green()
- embed.title = random.choice(POSITIVE_REPLIES)
- embed.description = on_success
+ embed = discord.Embed(
+ description=on_success,
+ colour=discord.Colour.green(),
+ title=random.choice(POSITIVE_REPLIES)
+ )
footer_str = f"ID: {reminder_id}"
+
if delivery_dt:
# Reminder deletion will have a `None` `delivery_dt`
- footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}"
+ footer_str += ', Due'
+ embed.timestamp = delivery_dt
embed.set_footer(text=footer_str)
diff --git a/bot/exts/utils/snekbox.py b/bot/exts/utils/snekbox.py
index 9f480c067..da95240bb 100644
--- a/bot/exts/utils/snekbox.py
+++ b/bot/exts/utils/snekbox.py
@@ -13,7 +13,7 @@ from discord.ext.commands import Cog, Context, command, guild_only
from bot.bot import Bot
from bot.constants import Categories, Channels, Roles, URLs
-from bot.decorators import in_whitelist
+from bot.decorators import not_in_blacklist
from bot.utils import send_to_paste_service
from bot.utils.messages import wait_for_deletion
@@ -38,9 +38,9 @@ RAW_CODE_REGEX = re.compile(
MAX_PASTE_LEN = 10000
-# `!eval` command whitelists
-EVAL_CHANNELS = (Channels.bot_commands, Channels.esoteric)
-EVAL_CATEGORIES = (Categories.help_available, Categories.help_in_use, Categories.voice)
+# `!eval` command whitelists and blacklists.
+NO_EVAL_CHANNELS = (Channels.python_general,)
+NO_EVAL_CATEGORIES = ()
EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners)
SIGKILL = 9
@@ -280,7 +280,7 @@ class Snekbox(Cog):
@command(name="eval", aliases=("e",))
@guild_only()
- @in_whitelist(channels=EVAL_CHANNELS, categories=EVAL_CATEGORIES, roles=EVAL_ROLES)
+ @not_in_blacklist(channels=NO_EVAL_CHANNELS, categories=NO_EVAL_CATEGORIES, override_roles=EVAL_ROLES)
async def eval_command(self, ctx: Context, *, code: str = None) -> None:
"""
Run Python code and get the results.
diff --git a/bot/exts/utils/utils.py b/bot/exts/utils/utils.py
index cae7f2593..4c39a7c2a 100644
--- a/bot/exts/utils/utils.py
+++ b/bot/exts/utils/utils.py
@@ -109,7 +109,7 @@ class Utils(Cog):
# handle if it's an index int
if isinstance(search_value, int):
upper_bound = len(zen_lines) - 1
- lower_bound = -1 * upper_bound
+ lower_bound = -1 * len(zen_lines)
if not (lower_bound <= search_value <= upper_bound):
raise BadArgument(f"Please provide an index between {lower_bound} and {upper_bound}.")
@@ -162,17 +162,27 @@ class Utils(Cog):
if len(snowflakes) > 1 and await has_no_roles_check(ctx, *STAFF_ROLES):
raise BadArgument("Cannot process more than one snowflake in one invocation.")
+ if not snowflakes:
+ raise BadArgument("At least one snowflake must be provided.")
+
+ embed = Embed(colour=Colour.blue())
+ embed.set_author(
+ name=f"Snowflake{'s'[:len(snowflakes)^1]}", # Deals with pluralisation
+ icon_url="https://github.com/twitter/twemoji/blob/master/assets/72x72/2744.png?raw=true"
+ )
+
+ lines = []
for snowflake in snowflakes:
created_at = snowflake_time(snowflake)
- embed = Embed(
- description=f"**Created at {created_at}** ({time_since(created_at, max_units=3)}).",
- colour=Colour.blue()
- )
- embed.set_author(
- name=f"Snowflake: {snowflake}",
- icon_url="https://github.com/twitter/twemoji/blob/master/assets/72x72/2744.png?raw=true"
- )
- await ctx.send(embed=embed)
+ lines.append(f"**{snowflake}**\nCreated at {created_at} ({time_since(created_at, max_units=3)}).")
+
+ await LinePaginator.paginate(
+ lines,
+ ctx=ctx,
+ embed=embed,
+ max_lines=5,
+ max_size=1000
+ )
@command(aliases=("poll",))
@has_any_role(*MODERATION_ROLES, Roles.project_leads, Roles.domain_leads)
diff --git a/bot/log.py b/bot/log.py
index e92233a33..4e20c005e 100644
--- a/bot/log.py
+++ b/bot/log.py
@@ -20,7 +20,6 @@ def setup() -> None:
logging.addLevelName(TRACE_LEVEL, "TRACE")
Logger.trace = _monkeypatch_trace
- log_level = TRACE_LEVEL if constants.DEBUG_MODE else logging.INFO
format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
log_format = logging.Formatter(format_string)
@@ -30,7 +29,6 @@ def setup() -> None:
file_handler.setFormatter(log_format)
root_log = logging.getLogger()
- root_log.setLevel(log_level)
root_log.addHandler(file_handler)
if "COLOREDLOGS_LEVEL_STYLES" not in os.environ:
@@ -44,11 +42,9 @@ def setup() -> None:
if "COLOREDLOGS_LOG_FORMAT" not in os.environ:
coloredlogs.DEFAULT_LOG_FORMAT = format_string
- if "COLOREDLOGS_LOG_LEVEL" not in os.environ:
- coloredlogs.DEFAULT_LOG_LEVEL = log_level
-
- coloredlogs.install(logger=root_log, stream=sys.stdout)
+ coloredlogs.install(level=logging.TRACE, logger=root_log, stream=sys.stdout)
+ root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO)
logging.getLogger("discord").setLevel(logging.WARNING)
logging.getLogger("websockets").setLevel(logging.WARNING)
logging.getLogger("chardet").setLevel(logging.WARNING)
@@ -57,6 +53,8 @@ def setup() -> None:
# Set back to the default of INFO even if asyncio's debug mode is enabled.
logging.getLogger("asyncio").setLevel(logging.INFO)
+ _set_trace_loggers()
+
def setup_sentry() -> None:
"""Set up the Sentry logging integrations."""
@@ -86,3 +84,30 @@ def _monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
"""
if self.isEnabledFor(TRACE_LEVEL):
self._log(TRACE_LEVEL, msg, args, **kwargs)
+
+
+def _set_trace_loggers() -> None:
+ """
+ Set loggers to the trace level according to the value from the BOT_TRACE_LOGGERS env var.
+
+ When the env var is a list of logger names delimited by a comma,
+ each of the listed loggers will be set to the trace level.
+
+ If this list is prefixed with a "!", all of the loggers except the listed ones will be set to the trace level.
+
+ Otherwise if the env var begins with a "*",
+ the root logger is set to the trace level and other contents are ignored.
+ """
+ level_filter = constants.Bot.trace_loggers
+ if level_filter:
+ if level_filter.startswith("*"):
+ logging.getLogger().setLevel(logging.TRACE)
+
+ elif level_filter.startswith("!"):
+ logging.getLogger().setLevel(logging.TRACE)
+ for logger_name in level_filter.strip("!,").split(","):
+ logging.getLogger(logger_name).setLevel(logging.DEBUG)
+
+ else:
+ for logger_name in level_filter.strip(",").split(","):
+ logging.getLogger(logger_name).setLevel(logging.TRACE)
diff --git a/bot/pagination.py b/bot/pagination.py
index 3b16cc9ff..c5c84afd9 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -2,14 +2,14 @@ import asyncio
import logging
import typing as t
from contextlib import suppress
+from functools import partial
import discord
-from discord import Member
from discord.abc import User
from discord.ext.commands import Context, Paginator
from bot import constants
-from bot.constants import MODERATION_ROLES
+from bot.utils import messages
FIRST_EMOJI = "\u23EE" # [:track_previous:]
LEFT_EMOJI = "\u2B05" # [:arrow_left:]
@@ -220,29 +220,6 @@ class LinePaginator(Paginator):
>>> embed.set_author(name="Some Operation", url=url, icon_url=icon)
>>> await LinePaginator.paginate([line for line in lines], ctx, embed)
"""
- def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool:
- """Make sure that this reaction is what we want to operate on."""
- no_restrictions = (
- # The reaction was by a whitelisted user
- user_.id == restrict_to_user.id
- # The reaction was by a moderator
- or isinstance(user_, Member) and any(role.id in MODERATION_ROLES for role in user_.roles)
- )
-
- return (
- # Conditions for a successful pagination:
- all((
- # Reaction is on this message
- reaction_.message.id == message.id,
- # Reaction is one of the pagination emotes
- str(reaction_.emoji) in PAGINATION_EMOJI,
- # Reaction was not made by the Bot
- user_.id != ctx.bot.user.id,
- # There were no restrictions
- no_restrictions
- ))
- )
-
paginator = cls(prefix=prefix, suffix=suffix, max_size=max_size, max_lines=max_lines,
scale_to_size=scale_to_size)
current_page = 0
@@ -303,9 +280,16 @@ class LinePaginator(Paginator):
log.trace(f"Adding reaction: {repr(emoji)}")
await message.add_reaction(emoji)
+ check = partial(
+ messages.reaction_check,
+ message_id=message.id,
+ allowed_emoji=PAGINATION_EMOJI,
+ allowed_users=(restrict_to_user.id,),
+ )
+
while True:
try:
- reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=event_check)
+ reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)
log.trace(f"Got reaction: {reaction}")
except asyncio.TimeoutError:
log.debug("Timed out waiting for a reaction")
diff --git a/bot/resources/tags/customchecks.md b/bot/resources/tags/customchecks.md
new file mode 100644
index 000000000..23ff7a66f
--- /dev/null
+++ b/bot/resources/tags/customchecks.md
@@ -0,0 +1,21 @@
+**Custom Command Checks in discord.py**
+
+Often you may find the need to use checks that don't exist by default in discord.py. Fortunately, discord.py provides `discord.ext.commands.check` which allows you to create your own checks like this:
+```py
+from discord.ext.commands import check, Context
+
+def in_any_channel(*channels):
+ async def predicate(ctx: Context):
+ return ctx.channel.id in channels
+ return check(predicate)
+```
+This check verifies that the invoked command is being used in one of the given channels. The inner function, named `predicate` here, performs the actual check, and the check logic belongs in this function. It must be an async function, and it always receives a single `commands.Context` argument which you can use to build the check logic. This check function should return a boolean value indicating whether the check passed (return `True`) or failed (return `False`).
+
+The check can now be used like any other commands check as a decorator of a command, such as this:
+```py
[email protected](name="ping")
+@in_any_channel(728343273562701984)
+async def ping(ctx: Context):
+ ...
+```
+This would lock the `ping` command to only be used in the channel `728343273562701984`. If this check function fails it will raise a `CheckFailure` exception, which can be handled in your error handler.
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index 460a937d8..3d0c8a50c 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -20,8 +20,8 @@ from bot import constants
log = logging.getLogger(__name__)
-class InWhitelistCheckFailure(CheckFailure):
- """Raised when the `in_whitelist` check fails."""
+class ContextCheckFailure(CheckFailure):
+ """Raised when a context-specific check fails."""
def __init__(self, redirect_channel: Optional[int]) -> None:
self.redirect_channel = redirect_channel
@@ -36,6 +36,10 @@ class InWhitelistCheckFailure(CheckFailure):
super().__init__(error_message)
+class InWhitelistCheckFailure(ContextCheckFailure):
+ """Raised when the `in_whitelist` check fails."""
+
+
def in_whitelist_check(
ctx: Context,
channels: Container[int] = (),
diff --git a/bot/utils/function.py b/bot/utils/function.py
index 3ab32fe3c..9bc44e753 100644
--- a/bot/utils/function.py
+++ b/bot/utils/function.py
@@ -1,14 +1,23 @@
"""Utilities for interaction with functions."""
+import functools
import inspect
+import logging
+import types
import typing as t
+log = logging.getLogger(__name__)
+
Argument = t.Union[int, str]
BoundArgs = t.OrderedDict[str, t.Any]
Decorator = t.Callable[[t.Callable], t.Callable]
ArgValGetter = t.Callable[[BoundArgs], t.Any]
+class GlobalNameConflictError(Exception):
+ """Raised when there's a conflict between the globals used to resolve annotations of wrapped and its wrapper."""
+
+
def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> t.Any:
"""
Return a value from `arguments` based on a name or position.
@@ -73,3 +82,66 @@ def get_bound_args(func: t.Callable, args: t.Tuple, kwargs: t.Dict[str, t.Any])
bound_args.apply_defaults()
return bound_args.arguments
+
+
+def update_wrapper_globals(
+ wrapper: types.FunctionType,
+ wrapped: types.FunctionType,
+ *,
+ ignored_conflict_names: t.Set[str] = frozenset(),
+) -> types.FunctionType:
+ """
+ Update globals of `wrapper` with the globals from `wrapped`.
+
+    For forwardrefs in command annotations discordpy uses the __globals__ attribute of the function
+    to resolve their values; this breaks with decorators that replace the function, because they have
+    their own globals.
+
+ This function creates a new function functionally identical to `wrapper`, which has the globals replaced with
+ a merge of `wrapped`s globals and the `wrapper`s globals.
+
+ An exception will be raised in case `wrapper` and `wrapped` share a global name that is used by
+ `wrapped`'s typehints and is not in `ignored_conflict_names`,
+ as this can cause incorrect objects being used by discordpy's converters.
+ """
+ annotation_global_names = (
+ ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str)
+ )
+ # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
+ shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names)
+ shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__) - ignored_conflict_names
+ if shared_globals:
+ raise GlobalNameConflictError(
+ f"wrapper and the wrapped function share the following "
+ f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add "
+ f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional."
+ )
+
+ new_globals = wrapper.__globals__.copy()
+ new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names)
+ return types.FunctionType(
+ code=wrapper.__code__,
+ globals=new_globals,
+ name=wrapper.__name__,
+ argdefs=wrapper.__defaults__,
+ closure=wrapper.__closure__,
+ )
+
+
+def command_wraps(
+ wrapped: types.FunctionType,
+ assigned: t.Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
+ updated: t.Sequence[str] = functools.WRAPPER_UPDATES,
+ *,
+ ignored_conflict_names: t.Set[str] = frozenset(),
+) -> t.Callable[[types.FunctionType], types.FunctionType]:
+ """Update the decorated function to look like `wrapped` and update globals for discordpy forwardref evaluation."""
+ def decorator(wrapper: types.FunctionType) -> types.FunctionType:
+ return functools.update_wrapper(
+ update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names),
+ wrapped,
+ assigned,
+ updated,
+ )
+
+ return decorator
diff --git a/bot/utils/lock.py b/bot/utils/lock.py
index e44776340..ec6f92cd4 100644
--- a/bot/utils/lock.py
+++ b/bot/utils/lock.py
@@ -1,13 +1,15 @@
import asyncio
import inspect
import logging
+import types
from collections import defaultdict
-from functools import partial, wraps
+from functools import partial
from typing import Any, Awaitable, Callable, Hashable, Union
from weakref import WeakValueDictionary
from bot.errors import LockedResourceError
from bot.utils import function
+from bot.utils.function import command_wraps
log = logging.getLogger(__name__)
__lock_dicts = defaultdict(WeakValueDictionary)
@@ -17,6 +19,35 @@ _IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
ResourceId = Union[Hashable, _IdCallable]
+class SharedEvent:
+ """
+ Context manager managing an internal event exposed through the wait coro.
+
+ While any code is executing in this context manager, the underlying event will not be set;
+ when all of the holders finish the event will be set.
+ """
+
+ def __init__(self):
+ self._active_count = 0
+ self._event = asyncio.Event()
+ self._event.set()
+
+ def __enter__(self):
+ """Increment the count of the active holders and clear the internal event."""
+ self._active_count += 1
+ self._event.clear()
+
+ def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001
+ """Decrement the count of the active holders; if 0 is reached set the internal event."""
+ self._active_count -= 1
+ if not self._active_count:
+ self._event.set()
+
+ async def wait(self) -> None:
+ """Wait for all active holders to exit."""
+ await self._event.wait()
+
+
def lock(
namespace: Hashable,
resource_id: ResourceId,
@@ -41,10 +72,10 @@ def lock(
If decorating a command, this decorator must go before (below) the `command` decorator.
"""
- def decorator(func: Callable) -> Callable:
+ def decorator(func: types.FunctionType) -> types.FunctionType:
name = func.__name__
- @wraps(func)
+ @command_wraps(func)
async def wrapper(*args, **kwargs) -> Any:
log.trace(f"{name}: mutually exclusive decorator called")
diff --git a/bot/utils/messages.py b/bot/utils/messages.py
index 077dd9569..2beead6af 100644
--- a/bot/utils/messages.py
+++ b/bot/utils/messages.py
@@ -3,6 +3,7 @@ import contextlib
import logging
import random
import re
+from functools import partial
from io import BytesIO
from typing import List, Optional, Sequence, Union
@@ -12,24 +13,66 @@ from discord.ext.commands import Context
import bot
from bot.constants import Emojis, MODERATION_ROLES, NEGATIVE_REPLIES
+from bot.utils import scheduling
log = logging.getLogger(__name__)
+def reaction_check(
+ reaction: discord.Reaction,
+ user: discord.abc.User,
+ *,
+ message_id: int,
+ allowed_emoji: Sequence[str],
+ allowed_users: Sequence[int],
+ allow_mods: bool = True,
+) -> bool:
+ """
+ Check if a reaction's emoji and author are allowed and the message is `message_id`.
+
+ If the user is not allowed, remove the reaction. Ignore reactions made by the bot.
+ If `allow_mods` is True, allow users with moderator roles even if they're not in `allowed_users`.
+ """
+ right_reaction = (
+ user != bot.instance.user
+ and reaction.message.id == message_id
+ and str(reaction.emoji) in allowed_emoji
+ )
+ if not right_reaction:
+ return False
+
+ is_moderator = (
+ allow_mods
+ and any(role.id in MODERATION_ROLES for role in getattr(user, "roles", []))
+ )
+
+ if user.id in allowed_users or is_moderator:
+ log.trace(f"Allowed reaction {reaction} by {user} on {reaction.message.id}.")
+ return True
+ else:
+ log.trace(f"Removing reaction {reaction} by {user} on {reaction.message.id}: disallowed user.")
+ scheduling.create_task(
+ reaction.message.remove_reaction(reaction.emoji, user),
+ HTTPException, # Suppress the HTTPException if adding the reaction fails
+ name=f"remove_reaction-{reaction}-{reaction.message.id}-{user}"
+ )
+ return False
+
+
async def wait_for_deletion(
message: discord.Message,
- user_ids: Sequence[discord.abc.Snowflake],
+ user_ids: Sequence[int],
deletion_emojis: Sequence[str] = (Emojis.trashcan,),
timeout: float = 60 * 5,
attach_emojis: bool = True,
- allow_moderation_roles: bool = True
+ allow_mods: bool = True
) -> None:
"""
Wait for up to `timeout` seconds for a reaction by any of the specified `user_ids` to delete the message.
An `attach_emojis` bool may be specified to determine whether to attach the given
`deletion_emojis` to the message in the given `context`.
- An `allow_moderation_roles` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete
+ An `allow_mods` bool may also be specified to allow anyone with a role in `MODERATION_ROLES` to delete
the message.
"""
if message.guild is None:
@@ -43,16 +86,13 @@ async def wait_for_deletion(
log.trace(f"Aborting wait_for_deletion: message {message.id} deleted prematurely.")
return
- def check(reaction: discord.Reaction, user: discord.Member) -> bool:
- """Check that the deletion emoji is reacted by the appropriate user."""
- return (
- reaction.message.id == message.id
- and str(reaction.emoji) in deletion_emojis
- and (
- user.id in user_ids
- or allow_moderation_roles and any(role.id in MODERATION_ROLES for role in user.roles)
- )
- )
+ check = partial(
+ reaction_check,
+ message_id=message.id,
+ allowed_emoji=deletion_emojis,
+ allowed_users=user_ids,
+ allow_mods=allow_mods,
+ )
with contextlib.suppress(asyncio.TimeoutError):
await bot.instance.wait_for('reaction_add', check=check, timeout=timeout)
@@ -141,14 +181,14 @@ def sub_clyde(username: Optional[str]) -> Optional[str]:
return username # Empty string or None
-async def send_denial(ctx: Context, reason: str) -> None:
+async def send_denial(ctx: Context, reason: str) -> discord.Message:
"""Send an embed denying the user with the given reason."""
embed = discord.Embed()
embed.colour = discord.Colour.red()
embed.title = random.choice(NEGATIVE_REPLIES)
embed.description = reason
- await ctx.send(embed=embed)
+ return await ctx.send(embed=embed)
def format_user(user: discord.abc.User) -> str:
diff --git a/bot/utils/scheduling.py b/bot/utils/scheduling.py
index 6843bae88..2dc485f24 100644
--- a/bot/utils/scheduling.py
+++ b/bot/utils/scheduling.py
@@ -161,18 +161,18 @@ class Scheduler:
self._log.error(f"Error in task #{task_id} {id(done_task)}!", exc_info=exception)
-def create_task(*args, **kwargs) -> asyncio.Task:
+def create_task(coro: t.Awaitable, *suppressed_exceptions: t.Type[Exception], **kwargs) -> asyncio.Task:
"""Wrapper for `asyncio.create_task` which logs exceptions raised in the task."""
- task = asyncio.create_task(*args, **kwargs)
- task.add_done_callback(_log_task_exception)
+ task = asyncio.create_task(coro, **kwargs)
+ task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
return task
-def _log_task_exception(task: asyncio.Task) -> None:
+def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: t.Tuple[t.Type[Exception]]) -> None:
"""Retrieve and log the exception raised in `task` if one exists."""
with contextlib.suppress(asyncio.CancelledError):
exception = task.exception()
# Log the exception if one exists.
- if exception:
+ if exception and not isinstance(exception, suppressed_exceptions):
log = logging.getLogger(__name__)
log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)
diff --git a/config-default.yml b/config-default.yml
index 8c6e18470..46475f845 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -1,7 +1,8 @@
bot:
- prefix: "!"
- sentry_dsn: !ENV "BOT_SENTRY_DSN"
- token: !ENV "BOT_TOKEN"
+ prefix: "!"
+ sentry_dsn: !ENV "BOT_SENTRY_DSN"
+ token: !ENV "BOT_TOKEN"
+ trace_loggers: !ENV "BOT_TRACE_LOGGERS"
clean:
# Maximum number of messages to traverse for clean commands
@@ -46,6 +47,8 @@ style:
badge_partner: "<:partner:748666453242413136>"
badge_staff: "<:discord_staff:743882896498098226>"
badge_verified_bot_developer: "<:verified_bot_dev:743882897299210310>"
+ bot: "<:bot:812712599464443914>"
+ verified_bot: "<:verified_bot:811645219220750347>"
defcon_shutdown: "<:defcondisabled:470326273952972810>"
defcon_unshutdown: "<:defconenabled:470326274213150730>"
@@ -260,7 +263,8 @@ guild:
devops: 409416496733880320
domain_leads: 807415650778742785
helpers: &HELPERS_ROLE 267630620367257601
- moderators: &MODS_ROLE 267629731250176001
+ moderators: &MODS_ROLE 831776746206265384
+ mod_team: &MOD_TEAM_ROLE 267629731250176001
owners: &OWNERS_ROLE 267627879762755584
project_leads: 815701647526330398
@@ -273,13 +277,14 @@ guild:
moderation_roles:
- *ADMINS_ROLE
+ - *MOD_TEAM_ROLE
- *MODS_ROLE
- *OWNERS_ROLE
staff_roles:
- *ADMINS_ROLE
- *HELPERS_ROLE
- - *MODS_ROLE
+ - *MOD_TEAM_ROLE
- *OWNERS_ROLE
webhooks:
diff --git a/tests/README.md b/tests/README.md
index 4f62edd68..092324123 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -114,7 +114,7 @@ class BotCogTests(unittest.TestCase):
### Mocking coroutines
-By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.
+By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. The [`AsyncMock`](https://docs.python.org/3/library/unittest.mock.html#unittest.mock.AsyncMock) that has been [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest) is an asynchronous version of `MagicMock` that can be used anywhere a coroutine is expected.
### Special mocks for some `discord.py` types
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
new file mode 100644
index 000000000..bd4fb5942
--- /dev/null
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -0,0 +1,550 @@
+import unittest
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+from discord.ext.commands import errors
+
+from bot.api import ResponseCodeError
+from bot.errors import InvalidInfractedUser, LockedResourceError
+from bot.exts.backend.error_handler import ErrorHandler, setup
+from bot.exts.info.tags import Tags
+from bot.exts.moderation.silence import Silence
+from bot.utils.checks import InWhitelistCheckFailure
+from tests.helpers import MockBot, MockContext, MockGuild, MockRole
+
+
+class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
+ """Tests for error handler functionality."""
+
+ def setUp(self):
+ self.bot = MockBot()
+ self.ctx = MockContext(bot=self.bot)
+
+ async def test_error_handler_already_handled(self):
+ """Should not do anything when error is already handled by local error handler."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ error = errors.CommandError()
+ error.handled = "foo"
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_command_not_found_error_not_invoked_by_handler(self):
+ """Should try first (un)silence channel, when fail, try to get tag."""
+ error = errors.CommandNotFound()
+ test_cases = (
+ {
+ "try_silence_return": True,
+ "called_try_get_tag": False
+ },
+ {
+ "try_silence_return": False,
+ "called_try_get_tag": False
+ },
+ {
+ "try_silence_return": False,
+ "called_try_get_tag": True
+ }
+ )
+ cog = ErrorHandler(self.bot)
+ cog.try_silence = AsyncMock()
+ cog.try_get_tag = AsyncMock()
+
+ for case in test_cases:
+ with self.subTest(try_silence_return=case["try_silence_return"], try_get_tag=case["called_try_get_tag"]):
+ self.ctx.reset_mock()
+ cog.try_silence.reset_mock(return_value=True)
+ cog.try_get_tag.reset_mock()
+
+ cog.try_silence.return_value = case["try_silence_return"]
+ self.ctx.channel.id = 1234
+
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+
+ if case["try_silence_return"]:
+ cog.try_get_tag.assert_not_awaited()
+ cog.try_silence.assert_awaited_once()
+ else:
+ cog.try_silence.assert_awaited_once()
+ cog.try_get_tag.assert_awaited_once()
+
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_command_not_found_error_invoked_by_handler(self):
+ """Should do nothing when error is `CommandNotFound` and have attribute `invoked_from_error_handler`."""
+ ctx = MockContext(bot=self.bot, invoked_from_error_handler=True)
+
+ cog = ErrorHandler(self.bot)
+ cog.try_silence = AsyncMock()
+ cog.try_get_tag = AsyncMock()
+
+ error = errors.CommandNotFound()
+
+ self.assertIsNone(await cog.on_command_error(ctx, error))
+
+ cog.try_silence.assert_not_awaited()
+ cog.try_get_tag.assert_not_awaited()
+ self.ctx.send.assert_not_awaited()
+
+ async def test_error_handler_user_input_error(self):
+ """Should await `ErrorHandler.handle_user_input_error` when error is `UserInputError`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ cog.handle_user_input_error = AsyncMock()
+ error = errors.UserInputError()
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ cog.handle_user_input_error.assert_awaited_once_with(self.ctx, error)
+
+ async def test_error_handler_check_failure(self):
+ """Should await `ErrorHandler.handle_check_failure` when error is `CheckFailure`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ cog.handle_check_failure = AsyncMock()
+ error = errors.CheckFailure()
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ cog.handle_check_failure.assert_awaited_once_with(self.ctx, error)
+
+ async def test_error_handler_command_on_cooldown(self):
+ """Should send error with `ctx.send` when error is `CommandOnCooldown`."""
+ self.ctx.reset_mock()
+ cog = ErrorHandler(self.bot)
+ error = errors.CommandOnCooldown(10, 9)
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ self.ctx.send.assert_awaited_once_with(error)
+
+ async def test_error_handler_command_invoke_error(self):
+ """Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_api_error = AsyncMock()
+ cog.handle_unexpected_error = AsyncMock()
+ test_cases = (
+ {
+ "args": (self.ctx, errors.CommandInvokeError(ResponseCodeError(AsyncMock()))),
+ "expect_mock_call": cog.handle_api_error
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(TypeError)),
+ "expect_mock_call": cog.handle_unexpected_error
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(LockedResourceError("abc", "test"))),
+ "expect_mock_call": "send"
+ },
+ {
+ "args": (self.ctx, errors.CommandInvokeError(InvalidInfractedUser(self.ctx.author))),
+ "expect_mock_call": "send"
+ }
+ )
+
+ for case in test_cases:
+ with self.subTest(args=case["args"], expect_mock_call=case["expect_mock_call"]):
+ self.ctx.send.reset_mock()
+ self.assertIsNone(await cog.on_command_error(*case["args"]))
+ if case["expect_mock_call"] == "send":
+ self.ctx.send.assert_awaited_once()
+ else:
+ case["expect_mock_call"].assert_awaited_once_with(
+ self.ctx, case["args"][1].original
+ )
+
+ async def test_error_handler_conversion_error(self):
+ """Should call `handle_api_error` or `handle_unexpected_error` depending on original error."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_api_error = AsyncMock()
+ cog.handle_unexpected_error = AsyncMock()
+ cases = (
+ {
+ "error": errors.ConversionError(AsyncMock(), ResponseCodeError(AsyncMock())),
+ "mock_function_to_call": cog.handle_api_error
+ },
+ {
+ "error": errors.ConversionError(AsyncMock(), TypeError),
+ "mock_function_to_call": cog.handle_unexpected_error
+ }
+ )
+
+ for case in cases:
+ with self.subTest(**case):
+ self.assertIsNone(await cog.on_command_error(self.ctx, case["error"]))
+ case["mock_function_to_call"].assert_awaited_once_with(self.ctx, case["error"].original)
+
+ async def test_error_handler_two_other_errors(self):
+ """Should call `handle_unexpected_error` if error is `MaxConcurrencyReached` or `ExtensionError`."""
+ cog = ErrorHandler(self.bot)
+ cog.handle_unexpected_error = AsyncMock()
+ errs = (
+ errors.MaxConcurrencyReached(1, MagicMock()),
+ errors.ExtensionError(name="foo")
+ )
+
+ for err in errs:
+ with self.subTest(error=err):
+ cog.handle_unexpected_error.reset_mock()
+ self.assertIsNone(await cog.on_command_error(self.ctx, err))
+ cog.handle_unexpected_error.assert_awaited_once_with(self.ctx, err)
+
+ @patch("bot.exts.backend.error_handler.log")
+ async def test_error_handler_other_errors(self, log_mock):
+ """Should `log.debug` other errors."""
+ cog = ErrorHandler(self.bot)
+ error = errors.DisabledCommand() # Use this just as a other error
+ self.assertIsNone(await cog.on_command_error(self.ctx, error))
+ log_mock.debug.assert_called_once()
+
+
class TrySilenceTests(unittest.IsolatedAsyncioTestCase):
    """Tests for the helpers that turn `CommandNotFound` into silence/unsilence invocations."""

    def setUp(self):
        self.bot = MockBot()
        self.silence = Silence(self.bot)
        self.bot.get_command.return_value = self.silence.silence
        self.ctx = MockContext(bot=self.bot)
        self.cog = ErrorHandler(self.bot)

    async def test_try_silence_context_invoked_from_error_handler(self):
        """`try_silence` should flag the context as invoked from the error handler."""
        self.ctx.invoked_with = "foo"
        await self.cog.try_silence(self.ctx)
        self.assertTrue(hasattr(self.ctx, "invoked_from_error_handler"))
        self.assertTrue(self.ctx.invoked_from_error_handler)

    async def test_try_silence_get_command(self):
        """`try_silence` should look the command up under the name `silence`."""
        self.ctx.invoked_with = "foo"
        await self.cog.try_silence(self.ctx)
        self.bot.get_command.assert_called_once_with("silence")

    async def test_try_silence_no_permissions_to_run(self):
        """A failing permission check should make `try_silence` return `False`."""
        self.ctx.invoked_with = "foo"
        self.bot.get_command.return_value.can_run = AsyncMock(return_value=False)
        self.assertFalse(await self.cog.try_silence(self.ctx))

    async def test_try_silence_no_permissions_to_run_command_error(self):
        """A `CommandError` raised by the permission check should yield `False`."""
        self.ctx.invoked_with = "foo"
        self.bot.get_command.return_value.can_run = AsyncMock(side_effect=errors.CommandError())
        self.assertFalse(await self.cog.try_silence(self.ctx))

    async def test_try_silence_silencing(self):
        """Messages of the shape `shh...` should invoke silence with a capped duration."""
        self.bot.get_command.return_value.can_run = AsyncMock(return_value=True)

        for message in ("shh", "shhh", "shhhhhh", "shhhhhhhhhhhhhhhhhhh"):
            with self.subTest(message=message):
                self.ctx.reset_mock()
                self.ctx.invoked_with = message
                # Duration scales with the number of h's but never exceeds 15.
                expected_duration = min(message.count("h") * 2, 15)
                self.assertTrue(await self.cog.try_silence(self.ctx))
                self.ctx.invoke.assert_awaited_once_with(
                    self.bot.get_command.return_value,
                    duration=expected_duration
                )

    async def test_try_silence_unsilence(self):
        """Messages of the shape `unshh...` should invoke the unsilence command."""
        self.silence.silence.can_run = AsyncMock(return_value=True)

        for message in ("unshh", "unshhhhh", "unshhhhhhhhh"):
            with self.subTest(message=message):
                # First lookup serves the permission check (silence), second the invocation (unsilence).
                self.bot.get_command.side_effect = (self.silence.silence, self.silence.unsilence)
                self.ctx.reset_mock()
                self.ctx.invoked_with = message
                self.assertTrue(await self.cog.try_silence(self.ctx))
                self.ctx.invoke.assert_awaited_once_with(self.silence.unsilence)

    async def test_try_silence_no_match(self):
        """`try_silence` should return `False` for messages matching neither pattern."""
        self.ctx.invoked_with = "foo"
        self.assertFalse(await self.cog.try_silence(self.ctx))
+
+
class TryGetTagTests(unittest.IsolatedAsyncioTestCase):
    """Tests for the `try_get_tag` helper."""

    def setUp(self):
        self.bot = MockBot()
        self.ctx = MockContext()
        self.tag = Tags(self.bot)
        self.cog = ErrorHandler(self.bot)
        self.bot.get_command.return_value = self.tag.get_command

    async def test_try_get_tag_get_command(self):
        """`try_get_tag` should look the command up under the name `tags get`."""
        self.bot.get_command.reset_mock()
        self.ctx.invoked_with = "foo"
        await self.cog.try_get_tag(self.ctx)
        self.bot.get_command.assert_called_once_with("tags get")

    async def test_try_get_tag_invoked_from_error_handler(self):
        """The context should end up flagged with `invoked_from_error_handler`."""
        self.ctx.invoked_from_error_handler = False
        self.ctx.invoked_with = "foo"
        await self.cog.try_get_tag(self.ctx)
        self.assertTrue(self.ctx.invoked_from_error_handler)

    async def test_try_get_tag_no_permissions(self):
        """A failing permission check should abort the tag lookup."""
        self.tag.get_command.can_run = AsyncMock(return_value=False)
        self.ctx.invoked_with = "foo"
        self.assertIsNone(await self.cog.try_get_tag(self.ctx))

    async def test_try_get_tag_command_error(self):
        """A `CommandError` from the permission check should be routed to `on_command_error`."""
        error = errors.CommandError()
        self.tag.get_command.can_run = AsyncMock(side_effect=error)
        self.cog.on_command_error = AsyncMock()
        self.ctx.invoked_with = "foo"
        self.assertIsNone(await self.cog.try_get_tag(self.ctx))
        self.cog.on_command_error.assert_awaited_once_with(self.ctx, error)

    @patch("bot.exts.backend.error_handler.TagNameConverter")
    async def test_try_get_tag_convert_success(self, converter_mock):
        """A successfully converted tag name should lead to an invocation."""
        self.ctx.invoked_with = "foo"
        converter_mock.convert = AsyncMock(return_value="foo")
        self.assertIsNone(await self.cog.try_get_tag(self.ctx))
        converter_mock.convert.assert_awaited_once_with(self.ctx, "foo")
        self.ctx.invoke.assert_awaited_once()

    @patch("bot.exts.backend.error_handler.TagNameConverter")
    async def test_try_get_tag_convert_fail(self, converter_mock):
        """A `BadArgument` during name conversion should prevent any invocation."""
        self.ctx.reset_mock()
        self.ctx.invoked_with = "bar"
        converter_mock.convert = AsyncMock(side_effect=errors.BadArgument())
        self.assertIsNone(await self.cog.try_get_tag(self.ctx))
        self.ctx.invoke.assert_not_awaited()

    async def test_try_get_tag_ctx_invoke(self):
        """`ctx.invoke` should receive the tag command and the tag name."""
        self.ctx.reset_mock()
        self.ctx.invoked_with = "foo"
        self.assertIsNone(await self.cog.try_get_tag(self.ctx))
        self.ctx.invoke.assert_awaited_once_with(self.tag.get_command, tag_name="foo")

    async def test_dont_call_suggestion_tag_sent(self):
        """No command suggestion should be sent once a tag has been sent."""
        self.ctx.invoked_with = "foo"
        # A truthy invoke result signals that a tag was delivered.
        self.ctx.invoke = AsyncMock(return_value=True)
        self.cog.send_command_suggestion = AsyncMock()

        await self.cog.try_get_tag(self.ctx)
        self.cog.send_command_suggestion.assert_not_awaited()

    @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234])
    async def test_dont_call_suggestion_if_user_mod(self):
        """No command suggestion should be sent to moderators."""
        self.ctx.invoked_with = "foo"
        self.ctx.invoke = AsyncMock(return_value=False)
        self.ctx.author.roles = [MockRole(id=1234)]
        self.cog.send_command_suggestion = AsyncMock()

        await self.cog.try_get_tag(self.ctx)
        self.cog.send_command_suggestion.assert_not_awaited()

    async def test_call_suggestion(self):
        """A command suggestion should be sent to non-moderators when no tag was sent."""
        self.ctx.invoked_with = "foo"
        self.ctx.invoke = AsyncMock(return_value=False)
        self.cog.send_command_suggestion = AsyncMock()

        await self.cog.try_get_tag(self.ctx)
        self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo")
+
+
class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
    """Tests for the individual error-category handlers."""

    def setUp(self):
        self.bot = MockBot()
        self.ctx = MockContext(bot=self.bot)
        self.cog = ErrorHandler(self.bot)

    async def test_handle_input_error_handler_errors(self):
        """Should handle each user input error properly."""
        test_cases = (
            {
                "error": errors.MissingRequiredArgument(MagicMock()),
                "call_prepared": True
            },
            {
                "error": errors.TooManyArguments(),
                "call_prepared": True
            },
            {
                "error": errors.BadArgument(),
                "call_prepared": True
            },
            {
                "error": errors.BadUnionArgument(MagicMock(), MagicMock(), MagicMock()),
                "call_prepared": True
            },
            {
                "error": errors.ArgumentParsingError(),
                "call_prepared": False
            },
            {
                "error": errors.UserInputError(),
                "call_prepared": True
            }
        )

        for case in test_cases:
            with self.subTest(error=case["error"], call_prepared=case["call_prepared"]):
                self.ctx.reset_mock()
                self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"]))
                self.ctx.send.assert_awaited_once()
                if case["call_prepared"]:
                    self.ctx.send_help.assert_awaited_once()
                else:
                    self.ctx.send_help.assert_not_awaited()

    async def test_handle_check_failure_errors(self):
        """Should await `ctx.send` when error is check failure."""
        test_cases = (
            {
                "error": errors.BotMissingPermissions(MagicMock()),
                "call_ctx_send": True
            },
            {
                "error": errors.BotMissingRole(MagicMock()),
                "call_ctx_send": True
            },
            {
                "error": errors.BotMissingAnyRole(MagicMock()),
                "call_ctx_send": True
            },
            {
                "error": errors.NoPrivateMessage(),
                "call_ctx_send": True
            },
            {
                "error": InWhitelistCheckFailure(1234),
                "call_ctx_send": True
            },
            {
                "error": ResponseCodeError(MagicMock()),
                "call_ctx_send": False
            }
        )

        for case in test_cases:
            with self.subTest(error=case["error"], call_ctx_send=case["call_ctx_send"]):
                self.ctx.reset_mock()
                await self.cog.handle_check_failure(self.ctx, case["error"])
                if case["call_ctx_send"]:
                    self.ctx.send.assert_awaited_once()
                else:
                    self.ctx.send.assert_not_awaited()

    @patch("bot.exts.backend.error_handler.log")
    async def test_handle_api_error(self, log_mock):
        """Should `ctx.send` on HTTP error codes, `log.debug|warning` depends on code."""
        test_cases = (
            {
                "error": ResponseCodeError(AsyncMock(status=400)),
                "log_level": "debug"
            },
            {
                "error": ResponseCodeError(AsyncMock(status=404)),
                "log_level": "debug"
            },
            {
                "error": ResponseCodeError(AsyncMock(status=550)),
                "log_level": "warning"
            },
            {
                "error": ResponseCodeError(AsyncMock(status=1000)),
                "log_level": "warning"
            }
        )

        for case in test_cases:
            with self.subTest(error=case["error"], log_level=case["log_level"]):
                self.ctx.reset_mock()
                log_mock.reset_mock()
                await self.cog.handle_api_error(self.ctx, case["error"])
                self.ctx.send.assert_awaited_once()
                if case["log_level"] == "warning":
                    log_mock.warning.assert_called_once()
                else:
                    log_mock.debug.assert_called_once()

    @patch("bot.exts.backend.error_handler.push_scope")
    @patch("bot.exts.backend.error_handler.log")
    async def test_handle_unexpected_error(self, log_mock, push_scope_mock):
        """Should `ctx.send` this error, error log this and sent to Sentry."""
        for case in (None, MockGuild()):
            with self.subTest(guild=case):
                self.ctx.reset_mock()
                log_mock.reset_mock()
                push_scope_mock.reset_mock()

                self.ctx.guild = case
                await self.cog.handle_unexpected_error(self.ctx, errors.CommandError())

                self.ctx.send.assert_awaited_once()
                log_mock.error.assert_called_once()
                push_scope_mock.assert_called_once()

                set_tag_calls = [
                    call("command", self.ctx.command.qualified_name),
                    call("message_id", self.ctx.message.id),
                    call("channel_id", self.ctx.channel.id),
                ]
                set_extra_calls = [
                    call("full_message", self.ctx.message.content)
                ]
                if case:
                    url = (
                        f"https://discordapp.com/channels/"
                        f"{self.ctx.guild.id}/{self.ctx.channel.id}/{self.ctx.message.id}"
                    )
                    set_extra_calls.append(call("jump_to", url))

                # Fixed: the original used `push_scope_mock.set_tag.has_calls(...)`.
                # `has_calls` is not a Mock assertion method -- attribute access just
                # creates a child mock -- so nothing was ever verified. The handler
                # uses `push_scope()` as a context manager, so the calls are recorded
                # on the entered scope object, not on the patched function mock.
                scope_mock = push_scope_mock.return_value.__enter__.return_value
                scope_mock.set_tag.assert_has_calls(set_tag_calls)
                scope_mock.set_extra.assert_has_calls(set_extra_calls)
+
+
class OtherErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
    """Remaining `ErrorHandler` tests."""

    def setUp(self):
        self.bot = MockBot()
        self.ctx = MockContext()

    async def _assert_equivalent_coroutines(self, result, expected):
        """Assert the two coroutines wrap the same call, then consume both."""
        self.assertEqual(result.__qualname__, expected.__qualname__)
        self.assertEqual(result.cr_frame.f_locals, expected.cr_frame.f_locals)
        # Await both so no "never awaited" warnings are emitted.
        await result
        await expected

    async def test_get_help_command_command_specified(self):
        """With a command set, the result should wrap `send_help` for that command."""
        self.ctx.command = "foo"
        result = ErrorHandler.get_help_command(self.ctx)
        expected = self.ctx.send_help("foo")
        await self._assert_equivalent_coroutines(result, expected)

    async def test_get_help_command_no_command_specified(self):
        """Without a command, the result should wrap a bare `send_help` call."""
        self.ctx.command = None
        result = ErrorHandler.get_help_command(self.ctx)
        expected = self.ctx.send_help()
        await self._assert_equivalent_coroutines(result, expected)
+
+
class ErrorHandlerSetupTests(unittest.TestCase):
    """Tests for the extension's `setup` entry point."""

    def test_setup(self):
        """`setup` should register the `ErrorHandler` cog via `bot.add_cog`."""
        mock_bot = MockBot()
        setup(mock_bot)
        mock_bot.add_cog.assert_called_once()
diff --git a/tests/bot/exts/info/doc/__init__.py b/tests/bot/exts/info/doc/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/exts/info/doc/__init__.py
diff --git a/tests/bot/exts/info/doc/test_parsing.py b/tests/bot/exts/info/doc/test_parsing.py
new file mode 100644
index 000000000..1663d8491
--- /dev/null
+++ b/tests/bot/exts/info/doc/test_parsing.py
@@ -0,0 +1,66 @@
+from unittest import TestCase
+
+from bot.exts.info.doc import _parsing as parsing
+
+
class SignatureSplitter(TestCase):
    """
    Tests for `_parsing._split_parameters`.

    The function splits a signature's parameter string on top-level commas,
    ignoring commas nested inside brackets or quoted strings.
    """

    def test_basic_split(self):
        """Flat parameter lists split on every comma."""
        test_cases = (
            ("0,0,0", ["0", "0", "0"]),
            ("0,a=0,a=0", ["0", "a=0", "a=0"]),
        )
        self._run_tests(test_cases)

    def test_commas_ignored_in_brackets(self):
        """Commas inside (possibly nested) brackets do not split."""
        test_cases = (
            ("0,[0,0],0,[0,0],0", ["0", "[0,0]", "0", "[0,0]", "0"]),
            ("(0,),0,(0,(0,),0),0", ["(0,)", "0", "(0,(0,),0)", "0"]),
        )
        self._run_tests(test_cases)

    def test_mixed_brackets(self):
        """Different bracket kinds can nest within each other."""
        # Fixed variable name: was misspelled `tests_cases`, inconsistent
        # with every other test in this class.
        test_cases = (
            ("[0,{0},0],0,{0:0},0", ["[0,{0},0]", "0", "{0:0}", "0"]),
            ("([0],0,0),0,(0,0),0", ["([0],0,0)", "0", "(0,0)", "0"]),
            ("([(0,),(0,)],0),0", ["([(0,),(0,)],0)", "0"]),
        )
        self._run_tests(test_cases)

    def test_string_contents_ignored(self):
        """Commas and brackets inside string literals do not affect splitting."""
        test_cases = (
            ("'0,0',0,',',0", ["'0,0'", "0", "','", "0"]),
            ("0,[']',0],0", ["0", "[']',0]", "0"]),
            ("{0,0,'}}',0,'{'},0", ["{0,0,'}}',0,'{'}", "0"]),
        )
        self._run_tests(test_cases)

    def test_mixed_quotes(self):
        """Single quotes inside double-quoted strings (and vice versa) are literal."""
        test_cases = (
            ("\"0',0',\",'0,0',0", ["\"0',0',\"", "'0,0'", "0"]),
            ("\",',\",'\",',0", ["\",',\"", "'\",'", "0"]),
        )
        self._run_tests(test_cases)

    def test_quote_escaped(self):
        """Backslash-escaped quotes do not terminate a string."""
        test_cases = (
            (r"'\',','\\',0", [r"'\','", r"'\\'", "0"]),
            (r"'0\',0\\\'\\',0", [r"'0\',0\\\'\\'", "0"]),
        )
        self._run_tests(test_cases)

    def test_real_signatures(self):
        """Signatures lifted from real documentation split as expected."""
        test_cases = (
            ("start, stop[, step]", ["start", " stop[, step]"]),
            ("object=b'', encoding='utf-8', errors='strict'", ["object=b''", " encoding='utf-8'", " errors='strict'"]),
            (
                "typename, field_names, *, rename=False, defaults=None, module=None",
                ["typename", " field_names", " *", " rename=False", " defaults=None", " module=None"]
            ),
        )
        self._run_tests(test_cases)

    def _run_tests(self, test_cases):
        """Assert `_split_parameters` output for each (input, expected) pair."""
        for input_string, expected_output in test_cases:
            with self.subTest(input_string=input_string):
                self.assertEqual(list(parsing._split_parameters(input_string)), expected_output)
diff --git a/tests/bot/exts/info/test_information.py b/tests/bot/exts/info/test_information.py
index a996ce477..770660fe3 100644
--- a/tests/bot/exts/info/test_information.py
+++ b/tests/bot/exts/info/test_information.py
@@ -281,6 +281,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""The embed should use the string representation of the user if they don't have a nick."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
+ user.public_flags = unittest.mock.MagicMock(verified_bot=False)
user.nick = None
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
user.colour = 0
@@ -297,6 +298,7 @@ class UserEmbedTests(unittest.IsolatedAsyncioTestCase):
"""The embed should use the nick if it's available."""
ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1))
user = helpers.MockMember()
+ user.public_flags = unittest.mock.MagicMock(verified_bot=False)
user.nick = "Cat lover"
user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
user.colour = 0
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
index c42111f3f..4af84dde5 100644
--- a/tests/bot/test_converters.py
+++ b/tests/bot/test_converters.py
@@ -10,9 +10,9 @@ from bot.converters import (
Duration,
HushDurationConverter,
ISODateTime,
+ PackageName,
TagContentConverter,
TagNameConverter,
- ValidPythonIdentifier,
)
@@ -78,24 +78,23 @@ class ConverterTests(unittest.IsolatedAsyncioTestCase):
with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
await TagNameConverter.convert(self.context, invalid_name)
- async def test_valid_python_identifier_for_valid(self):
- """ValidPythonIdentifier returns valid identifiers unchanged."""
- test_values = ('foo', 'lemon')
+ async def test_package_name_for_valid(self):
+ """PackageName returns valid package names unchanged."""
+ test_values = ('foo', 'le_mon', 'num83r')
for name in test_values:
with self.subTest(identifier=name):
- conversion = await ValidPythonIdentifier.convert(self.context, name)
+ conversion = await PackageName.convert(self.context, name)
self.assertEqual(name, conversion)
- async def test_valid_python_identifier_for_invalid(self):
- """ValidPythonIdentifier raises the proper exception for invalid identifiers."""
- test_values = ('nested.stuff', '#####')
+ async def test_package_name_for_invalid(self):
+ """PackageName raises the proper exception for invalid package names."""
+ test_values = ('text_with_a_dot.', 'UpperCaseName', 'dashed-name')
for name in test_values:
with self.subTest(identifier=name):
- exception_message = f'`{name}` is not a valid Python identifier'
- with self.assertRaisesRegex(BadArgument, re.escape(exception_message)):
- await ValidPythonIdentifier.convert(self.context, name)
+ with self.assertRaises(BadArgument):
+ await PackageName.convert(self.context, name)
async def test_duration_converter_for_valid(self):
"""Duration returns the correct `datetime` for valid duration strings."""
diff --git a/tests/helpers.py b/tests/helpers.py
index 496363ae3..e3dc5fe5b 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -385,6 +385,7 @@ message_instance = discord.Message(state=state, channel=channel, data=message_da
# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
+context_instance.invoked_from_error_handler = None
class MockContext(CustomMockMixin, unittest.mock.MagicMock):
@@ -402,6 +403,7 @@ class MockContext(CustomMockMixin, unittest.mock.MagicMock):
self.guild = kwargs.get('guild', MockGuild())
self.author = kwargs.get('author', MockMember())
self.channel = kwargs.get('channel', MockTextChannel())
+ self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False)
attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock())