aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.coveragerc5
-rw-r--r--.gitignore7
-rw-r--r--Pipfile6
-rw-r--r--Pipfile.lock124
-rw-r--r--azure-pipelines.yml11
-rw-r--r--bot/__main__.py5
-rw-r--r--bot/cogs/alias.py13
-rw-r--r--bot/cogs/antimalware.py56
-rw-r--r--bot/cogs/antispam.py14
-rw-r--r--bot/cogs/clean.py2
-rw-r--r--bot/cogs/cogs.py298
-rw-r--r--bot/cogs/defcon.py163
-rw-r--r--bot/cogs/doc.py14
-rw-r--r--bot/cogs/extensions.py236
-rw-r--r--bot/cogs/filtering.py47
-rw-r--r--bot/cogs/free.py25
-rw-r--r--bot/cogs/help.py48
-rw-r--r--bot/cogs/information.py312
-rw-r--r--bot/cogs/logging.py6
-rw-r--r--bot/cogs/moderation.py1172
-rw-r--r--bot/cogs/moderation/__init__.py25
-rw-r--r--bot/cogs/moderation/infractions.py617
-rw-r--r--bot/cogs/moderation/management.py272
-rw-r--r--bot/cogs/moderation/modlog.py (renamed from bot/cogs/modlog.py)149
-rw-r--r--bot/cogs/moderation/superstarify.py (renamed from bot/cogs/superstarify/__init__.py)105
-rw-r--r--bot/cogs/moderation/utils.py172
-rw-r--r--bot/cogs/off_topic_names.py62
-rw-r--r--bot/cogs/reddit.py232
-rw-r--r--bot/cogs/reminders.py19
-rw-r--r--bot/cogs/site.py10
-rw-r--r--bot/cogs/snekbox.py25
-rw-r--r--bot/cogs/superstarify/stars.py87
-rw-r--r--bot/cogs/sync/cog.py6
-rw-r--r--bot/cogs/token_remover.py12
-rw-r--r--bot/cogs/utils.py132
-rw-r--r--bot/cogs/verification.py40
-rw-r--r--bot/cogs/watchchannels/bigbrother.py30
-rw-r--r--bot/cogs/watchchannels/talentpool.py19
-rw-r--r--bot/cogs/watchchannels/watchchannel.py2
-rw-r--r--bot/constants.py31
-rw-r--r--bot/converters.py44
-rw-r--r--bot/decorators.py69
-rw-r--r--bot/resources/stars.json160
-rw-r--r--bot/utils/checks.py54
-rw-r--r--bot/utils/moderation.py72
-rw-r--r--bot/utils/time.py19
-rw-r--r--config-default.yml43
-rw-r--r--docker-compose.yml2
-rw-r--r--tests/README.md213
-rw-r--r--tests/__init__.py5
-rw-r--r--tests/base.py67
-rw-r--r--tests/bot/__init__.py (renamed from tests/cogs/__init__.py)0
-rw-r--r--tests/bot/cogs/__init__.py (renamed from tests/cogs/sync/__init__.py)0
-rw-r--r--tests/bot/cogs/sync/__init__.py (renamed from tests/rules/__init__.py)0
-rw-r--r--tests/bot/cogs/sync/test_roles.py126
-rw-r--r--tests/bot/cogs/sync/test_users.py84
-rw-r--r--tests/bot/cogs/test_antispam.py35
-rw-r--r--tests/bot/cogs/test_information.py582
-rw-r--r--tests/bot/cogs/test_security.py59
-rw-r--r--tests/bot/cogs/test_token_remover.py135
-rw-r--r--tests/bot/patches/__init__.py (renamed from tests/utils/__init__.py)0
-rw-r--r--tests/bot/resources/__init__.py0
-rw-r--r--tests/bot/resources/test_resources.py17
-rw-r--r--tests/bot/rules/__init__.py0
-rw-r--r--tests/bot/rules/test_attachments.py52
-rw-r--r--tests/bot/test_api.py134
-rw-r--r--tests/bot/test_constants.py26
-rw-r--r--tests/bot/test_converters.py273
-rw-r--r--tests/bot/test_pagination.py (renamed from tests/test_pagination.py)13
-rw-r--r--tests/bot/test_utils.py52
-rw-r--r--tests/bot/utils/__init__.py0
-rw-r--r--tests/bot/utils/test_checks.py51
-rw-r--r--tests/cogs/sync/test_roles.py103
-rw-r--r--tests/cogs/sync/test_users.py69
-rw-r--r--tests/cogs/test_antispam.py30
-rw-r--r--tests/cogs/test_information.py163
-rw-r--r--tests/cogs/test_security.py54
-rw-r--r--tests/cogs/test_token_remover.py133
-rw-r--r--tests/conftest.py32
-rw-r--r--tests/helpers.py384
-rw-r--r--tests/rules/test_attachments.py52
-rw-r--r--tests/test_api.py106
-rw-r--r--tests/test_base.py91
-rw-r--r--tests/test_constants.py23
-rw-r--r--tests/test_converters.py186
-rw-r--r--tests/test_helpers.py428
-rw-r--r--tests/test_resources.py18
-rw-r--r--tests/utils/test_checks.py66
-rw-r--r--tests/utils/test_time.py62
89 files changed, 5522 insertions, 3476 deletions
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 000000000..d572bd705
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,5 @@
+[run]
+branch = true
+source =
+ bot
+ tests
diff --git a/.gitignore b/.gitignore
index 261fa179f..fb3156ab1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -114,5 +114,8 @@ log.*
# Custom user configuration
config.yml
-# JUnit XML reports from pytest
-junit.xml
+# xmlrunner unittest XML reports
+TEST-**.xml
+
+# Mac OS .DS_Store, which is a file that stores custom attributes of its containing folder
+.DS_Store
diff --git a/Pipfile b/Pipfile
index 82847b23f..48d839fc3 100644
--- a/Pipfile
+++ b/Pipfile
@@ -21,6 +21,7 @@ more_itertools = "~=7.2"
urllib3 = ">=1.24.2,<1.25"
[dev-packages]
+coverage = "~=4.5"
flake8 = "~=3.7"
flake8-annotations = "~=1.1"
flake8-bugbear = "~=19.8"
@@ -31,9 +32,8 @@ flake8-tidy-imports = "~=2.0"
flake8-todo = "~=0.7"
pre-commit = "~=1.18"
safety = "~=1.8"
+unittest-xml-reporting = "~=2.5"
dodgy = "~=0.1"
-pytest = "~=5.1"
-pytest-cov = "~=2.7"
[requires]
python_version = "3.7"
@@ -44,3 +44,5 @@ lint = "python -m flake8"
precommit = "pre-commit install"
build = "docker build -t pythondiscord/bot:latest -f Dockerfile ."
push = "docker push pythondiscord/bot:latest"
+test = "coverage run -m unittest"
+report = "coverage report"
diff --git a/Pipfile.lock b/Pipfile.lock
index 4e6b4eaf8..95955ff89 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
- "sha256": "c2537cc3c5b0886d0b38f9b48f4f4b93e1e74d925454aa71a2189bddedadde42"
+ "sha256": "c27d699b4aeeed204dee41f924f682ae2a670add8549a8826e58776594370582"
},
"pipfile-spec": 6,
"requires": {
@@ -18,11 +18,11 @@
"default": {
"aio-pika": {
"hashes": [
- "sha256:29f27a8092169924c9eefb0c5e428d216706618dc9caa75ddb7759638e16cf26",
- "sha256:4f77ba9b6e7bc27fc88c49638bc3657ae5d4a2539e17fa0c2b25b370547b1b50"
+ "sha256:1dcec3e3e3309e277511dc0d7d157676d0165c174a6a745673fc9cf0510db8f0",
+ "sha256:dd5a23ca26a4872ee73bd107e4c545bace572cdec2a574aeb61f4062c7774b2a"
],
"index": "pypi",
- "version": "==6.1.2"
+ "version": "==6.1.3"
},
"aiodns": {
"hashes": [
@@ -83,10 +83,10 @@
},
"attrs": {
"hashes": [
- "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79",
- "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399"
+ "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2",
+ "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"
],
- "version": "==19.1.0"
+ "version": "==19.2.0"
},
"babel": {
"hashes": [
@@ -97,11 +97,11 @@
},
"beautifulsoup4": {
"hashes": [
- "sha256:05668158c7b85b791c5abde53e50265e16f98ad601c402ba44d70f96c4159612",
- "sha256:25288c9e176f354bf277c0a10aa96c782a6a18a17122dba2e8cec4a97e03343b",
- "sha256:f040590be10520f2ea4c2ae8c3dae441c7cfff5308ec9d58a0ec0c1b8f81d469"
+ "sha256:5279c36b4b2ec2cb4298d723791467e3000e5384a43ea0cdf5d45207c7e97169",
+ "sha256:6135db2ba678168c07950f9a16c4031822c6f4aec75a65e0a97bc5ca09789931",
+ "sha256:dcdef580e18a76d54002088602eba453eec38ebbcafafeaabd8cab12b6155d57"
],
- "version": "==4.8.0"
+ "version": "==4.8.1"
},
"certifi": {
"hashes": [
@@ -150,13 +150,6 @@
],
"version": "==3.0.4"
},
- "colorama": {
- "hashes": [
- "sha256:05eed71e2e327246ad6b38c540c4a3117230b19679b875190486ddd2d721422d",
- "sha256:f8ac84de7840f5b9c4e3347b3c1eaa50f7e49c2b07596221daec5edaabbd7c48"
- ],
- "version": "==0.4.1"
- },
"deepdiff": {
"hashes": [
"sha256:1123762580af0904621136d117c8397392a244d3ff0fa0a50de57a7939582476",
@@ -204,10 +197,10 @@
},
"jinja2": {
"hashes": [
- "sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013",
- "sha256:14dd6caf1527abb21f08f86c784eac40853ba93edb79552aa1e4b8aef1b61c7b"
+ "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f",
+ "sha256:9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de"
],
- "version": "==2.10.1"
+ "version": "==2.10.3"
},
"jsonpickle": {
"hashes": [
@@ -407,10 +400,10 @@
},
"pytz": {
"hashes": [
- "sha256:26c0b32e437e54a18161324a2fca3c4b9846b74a8dccddd843113109e1116b32",
- "sha256:c894d57500a4cd2d5c71114aaab77dbab5eabd9022308ce5ac9bb93a60a6f0c7"
+ "sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d",
+ "sha256:b02c06db6cf09c12dd25137e563b31700d3b80fcc4ad23abb7a315f2789819be"
],
- "version": "==2019.2"
+ "version": "==2019.3"
},
"pyyaml": {
"hashes": [
@@ -448,9 +441,10 @@
},
"snowballstemmer": {
"hashes": [
- "sha256:713e53b79cbcf97bc5245a06080a33d54a77e7cce2f789c835a143bcdb5c033e"
+ "sha256:209f257d7533fdb3cb73bdbd24f436239ca3b2fa67d56f6ff88e86be08cc5ef0",
+ "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"
],
- "version": "==1.9.1"
+ "version": "==2.0.0"
},
"soupsieve": {
"hashes": [
@@ -568,19 +562,12 @@
],
"version": "==1.3.0"
},
- "atomicwrites": {
- "hashes": [
- "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
- "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"
- ],
- "version": "==1.3.0"
- },
"attrs": {
"hashes": [
- "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79",
- "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399"
+ "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2",
+ "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"
],
- "version": "==19.1.0"
+ "version": "==19.2.0"
},
"certifi": {
"hashes": [
@@ -610,13 +597,6 @@
],
"version": "==7.0"
},
- "colorama": {
- "hashes": [
- "sha256:05eed71e2e327246ad6b38c540c4a3117230b19679b875190486ddd2d721422d",
- "sha256:f8ac84de7840f5b9c4e3347b3c1eaa50f7e49c2b07596221daec5edaabbd7c48"
- ],
- "version": "==0.4.1"
- },
"coverage": {
"hashes": [
"sha256:08907593569fe59baca0bf152c43f3863201efb6113ecb38ce7e97ce339805a6",
@@ -652,6 +632,7 @@
"sha256:fa964bae817babece5aa2e8c1af841bebb6d0b9add8e637548809d040443fee0",
"sha256:ff37757e068ae606659c28c3bd0d923f9d29a85de79bf25b2b34b148473b5025"
],
+ "index": "pypi",
"version": "==4.5.4"
},
"dodgy": {
@@ -701,11 +682,11 @@
},
"flake8-docstrings": {
"hashes": [
- "sha256:1666dd069c9c457ee57e80af3c1a6b37b00cc1801c6fde88e455131bb2e186cd",
- "sha256:9c0db5a79a1affd70fdf53b8765c8a26bf968e59e0252d7f2fc546b41c0cda06"
+ "sha256:3d5a31c7ec6b7367ea6506a87ec293b94a0a46c0bce2bb4975b7f1d09b6f3717",
+ "sha256:a256ba91bc52307bef1de59e2a009c3cf61c3d0952dbe035d6ff7208940c2edc"
],
"index": "pypi",
- "version": "==1.4.0"
+ "version": "==1.5.0"
},
"flake8-import-order": {
"hashes": [
@@ -757,7 +738,6 @@
"sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26",
"sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af"
],
- "markers": "python_version < '3.8'",
"version": "==0.23"
},
"mccabe": {
@@ -788,13 +768,6 @@
],
"version": "==19.2"
},
- "pluggy": {
- "hashes": [
- "sha256:0db4b7601aae1d35b4a033282da476845aa19185c1e6964b25cf324b5e4ec3e6",
- "sha256:fa5fa1622fa6dd5c030e9cad086fa19ef6a0cf6d7a2d12318e10cb49d6d68f34"
- ],
- "version": "==0.13.0"
- },
"pre-commit": {
"hashes": [
"sha256:1d3c0587bda7c4e537a46c27f2c84aa006acc18facf9970bf947df596ce91f3f",
@@ -803,13 +776,6 @@
"index": "pypi",
"version": "==1.18.3"
},
- "py": {
- "hashes": [
- "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa",
- "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53"
- ],
- "version": "==1.8.0"
- },
"pycodestyle": {
"hashes": [
"sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56",
@@ -838,22 +804,6 @@
],
"version": "==2.4.2"
},
- "pytest": {
- "hashes": [
- "sha256:813b99704b22c7d377bbd756ebe56c35252bb710937b46f207100e843440b3c2",
- "sha256:cc6620b96bc667a0c8d4fa592a8c9c94178a1bd6cc799dbb057dfd9286d31a31"
- ],
- "index": "pypi",
- "version": "==5.1.3"
- },
- "pytest-cov": {
- "hashes": [
- "sha256:2b097cde81a302e1047331b48cadacf23577e431b61e9c6f49a1170bbe3d3da6",
- "sha256:e00ea4fdde970725482f1f35630d12f074e121a23801aabf2ae154ec6bdd343a"
- ],
- "index": "pypi",
- "version": "==2.7.1"
- },
"pyyaml": {
"hashes": [
"sha256:0113bc0ec2ad727182326b61326afa3d1d8280ae1122493553fd6f4397f33df9",
@@ -898,9 +848,10 @@
},
"snowballstemmer": {
"hashes": [
- "sha256:713e53b79cbcf97bc5245a06080a33d54a77e7cce2f789c835a143bcdb5c033e"
+ "sha256:209f257d7533fdb3cb73bdbd24f436239ca3b2fa67d56f6ff88e86be08cc5ef0",
+ "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"
],
- "version": "==1.9.1"
+ "version": "==2.0.0"
},
"toml": {
"hashes": [
@@ -929,6 +880,14 @@
],
"version": "==1.4.0"
},
+ "unittest-xml-reporting": {
+ "hashes": [
+ "sha256:140982e4b58e4052d9ecb775525b246a96bfc1fc26097806e05ea06e9166dd6c",
+ "sha256:d1fbc7a1b6c6680ccfe75b5e9701e5431c646970de049e687b4bb35ba4325d72"
+ ],
+ "index": "pypi",
+ "version": "==2.5.1"
+ },
"urllib3": {
"hashes": [
"sha256:2393a695cd12afedd0dcb26fe5d50d0cf248e5a66f75dbd89a3d4eb333a61af4",
@@ -944,13 +903,6 @@
],
"version": "==16.7.5"
},
- "wcwidth": {
- "hashes": [
- "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e",
- "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c"
- ],
- "version": "==0.1.7"
- },
"zipp": {
"hashes": [
"sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e",
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index c22bac089..da3b06201 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -30,9 +30,12 @@ jobs:
- script: python -m flake8
displayName: 'Run linter'
- - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz python -m pytest --junitxml=junit.xml --cov=bot --cov-branch --cov-report=term --cov-report=xml tests
+ - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner
displayName: Run tests
+ - script: coverage report -m && coverage xml -o coverage.xml
+ displayName: Generate test coverage report
+
- task: PublishCodeCoverageResults@1
displayName: 'Publish Coverage Results'
condition: succeededOrFailed()
@@ -41,11 +44,11 @@ jobs:
summaryFileLocation: coverage.xml
- task: PublishTestResults@2
- displayName: 'Publish Test Results'
condition: succeededOrFailed()
+ displayName: 'Publish Test Results'
inputs:
- testResultsFiles: junit.xml
- testRunTitle: 'Bot Test results'
+ testResultsFiles: '**/TEST-*.xml'
+ testRunTitle: 'Bot Test Results'
- job: build
displayName: 'Build & Push Container'
diff --git a/bot/__main__.py b/bot/__main__.py
index f25693734..f352cd60e 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -36,14 +36,14 @@ log.addHandler(APILoggingHandler(bot.api_client))
bot.load_extension("bot.cogs.error_handler")
bot.load_extension("bot.cogs.filtering")
bot.load_extension("bot.cogs.logging")
-bot.load_extension("bot.cogs.modlog")
bot.load_extension("bot.cogs.security")
# Commands, etc
+bot.load_extension("bot.cogs.antimalware")
bot.load_extension("bot.cogs.antispam")
bot.load_extension("bot.cogs.bot")
bot.load_extension("bot.cogs.clean")
-bot.load_extension("bot.cogs.cogs")
+bot.load_extension("bot.cogs.extensions")
bot.load_extension("bot.cogs.help")
# Only load this in production
@@ -64,7 +64,6 @@ bot.load_extension("bot.cogs.reddit")
bot.load_extension("bot.cogs.reminders")
bot.load_extension("bot.cogs.site")
bot.load_extension("bot.cogs.snekbox")
-bot.load_extension("bot.cogs.superstarify")
bot.load_extension("bot.cogs.sync")
bot.load_extension("bot.cogs.tags")
bot.load_extension("bot.cogs.token_remover")
diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py
index 0f49a400c..5190c559b 100644
--- a/bot/cogs/alias.py
+++ b/bot/cogs/alias.py
@@ -5,6 +5,7 @@ from typing import Union
from discord import Colour, Embed, Member, User
from discord.ext.commands import Bot, Cog, Command, Context, clean_content, command, group
+from bot.cogs.extensions import Extension
from bot.cogs.watchchannels.watchchannel import proxy_user
from bot.converters import TagNameConverter
from bot.pagination import LinePaginator
@@ -78,15 +79,15 @@ class Alias (Cog):
"""Alias for invoking <prefix>site faq."""
await self.invoke(ctx, "site faq")
- @command(name="rules", hidden=True)
- async def site_rules_alias(self, ctx: Context) -> None:
+ @command(name="rules", aliases=("rule",), hidden=True)
+ async def site_rules_alias(self, ctx: Context, *rules: int) -> None:
"""Alias for invoking <prefix>site rules."""
- await self.invoke(ctx, "site rules")
+ await self.invoke(ctx, "site rules", *rules)
@command(name="reload", hidden=True)
- async def cogs_reload_alias(self, ctx: Context, *, cog_name: str) -> None:
- """Alias for invoking <prefix>cogs reload [cog_name]."""
- await self.invoke(ctx, "cogs reload", cog_name)
+ async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None:
+ """Alias for invoking <prefix>extensions reload [extensions...]."""
+ await self.invoke(ctx, "extensions reload", *extensions)
@command(name="defon", hidden=True)
async def defcon_enable_alias(self, ctx: Context) -> None:
diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py
new file mode 100644
index 000000000..ababd6f18
--- /dev/null
+++ b/bot/cogs/antimalware.py
@@ -0,0 +1,56 @@
+import logging
+
+from discord import Message, NotFound
+from discord.ext.commands import Bot, Cog
+
+from bot.constants import AntiMalware as AntiMalwareConfig, Channels
+
+log = logging.getLogger(__name__)
+
+
+class AntiMalware(Cog):
+ """Delete messages which contain attachments with non-whitelisted file extensions."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+
+ @Cog.listener()
+ async def on_message(self, message: Message) -> None:
+ """Identify messages with prohibited attachments."""
+ rejected_attachments = False
+ detected_pyfile = False
+ for attachment in message.attachments:
+ if attachment.filename.lower().endswith('.py'):
+ detected_pyfile = True
+ break # Other detections irrelevant because we prioritize the .py message.
+ if not attachment.filename.lower().endswith(tuple(AntiMalwareConfig.whitelist)):
+ rejected_attachments = True
+
+ if detected_pyfile or rejected_attachments:
+ # Send a message to the user indicating the problem (with special treatment for .py)
+ author = message.author
+ if detected_pyfile:
+ msg = (
+ f"{author.mention}, it looks like you tried to attach a Python file - please "
+ f"use a code-pasting service such as https://paste.pythondiscord.com/ instead."
+ )
+ else:
+ meta_channel = self.bot.get_channel(Channels.meta)
+ msg = (
+ f"{author.mention}, it looks like you tried to attach a file type we don't "
+ f"allow. Feel free to ask in {meta_channel.mention} if you think this is a mistake."
+ )
+
+ await message.channel.send(msg)
+
+ # Delete the offending message:
+ try:
+ await message.delete()
+ except NotFound:
+ log.info(f"Tried to delete message `{message.id}`, but message could not be found.")
+
+
+def setup(bot: Bot) -> None:
+ """Antimalware cog load."""
+ bot.add_cog(AntiMalware(bot))
+ log.info("Cog loaded: AntiMalware")
diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py
index 8dfa0ad05..1340eb608 100644
--- a/bot/cogs/antispam.py
+++ b/bot/cogs/antispam.py
@@ -10,7 +10,7 @@ from discord import Colour, Member, Message, NotFound, Object, TextChannel
from discord.ext.commands import Bot, Cog
from bot import rules
-from bot.cogs.modlog import ModLog
+from bot.cogs.moderation import ModLog
from bot.constants import (
AntiSpam as AntiSpamConfig, Channels,
Colours, DEBUG_MODE, Event, Filter,
@@ -59,7 +59,7 @@ class DeletionContext:
async def upload_messages(self, actor_id: int, modlog: ModLog) -> None:
"""Method that takes care of uploading the queue and posting modlog alert."""
- triggered_by_users = ", ".join(f"{m.display_name}#{m.discriminator} (`{m.id}`)" for m in self.members.values())
+ triggered_by_users = ", ".join(f"{m} (`{m.id}`)" for m in self.members.values())
mod_alert_message = (
f"**Triggered by:** {triggered_by_users}\n"
@@ -107,14 +107,16 @@ class AntiSpam(Cog):
self.message_deletion_queue = dict()
self.queue_consumption_tasks = dict()
+ self.bot.loop.create_task(self.alert_on_validation_error())
+
@property
def mod_log(self) -> ModLog:
"""Allows for easy access of the ModLog cog."""
return self.bot.get_cog("ModLog")
- @Cog.listener()
- async def on_ready(self) -> None:
+ async def alert_on_validation_error(self) -> None:
"""Unloads the cog and alerts admins if configuration validation failed."""
+ await self.bot.wait_until_ready()
if self.validation_errors:
body = "**The following errors were encountered:**\n"
body += "\n".join(f"- {error}" for error in self.validation_errors.values())
@@ -207,8 +209,10 @@ class AntiSpam(Cog):
if not any(role.id == self.muted_role.id for role in member.roles):
remove_role_after = AntiSpamConfig.punishment['remove_after']
- # We need context, let's get it
+ # Get context and make sure the bot becomes the actor of infraction by patching the `author` attributes
context = await self.bot.get_context(msg)
+ context.author = self.bot.user
+ context.message.author = self.bot.user
# Since we're going to invoke the tempmute command directly, we need to manually call the converter.
dt_remove_role_after = await self.expiration_date_converter.convert(context, f"{remove_role_after}S")
diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py
index 1c0c9a7a8..dca411d01 100644
--- a/bot/cogs/clean.py
+++ b/bot/cogs/clean.py
@@ -6,7 +6,7 @@ from typing import Optional
from discord import Colour, Embed, Message, User
from discord.ext.commands import Bot, Cog, Context, group
-from bot.cogs.modlog import ModLog
+from bot.cogs.moderation import ModLog
from bot.constants import (
Channels, CleanMessages, Colours, Event,
Icons, MODERATION_ROLES, NEGATIVE_REPLIES
diff --git a/bot/cogs/cogs.py b/bot/cogs/cogs.py
deleted file mode 100644
index 1f6ccd09c..000000000
--- a/bot/cogs/cogs.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import logging
-import os
-
-from discord import Colour, Embed
-from discord.ext.commands import Bot, Cog, Context, group
-
-from bot.constants import (
- Emojis, MODERATION_ROLES, Roles, URLs
-)
-from bot.decorators import with_role
-from bot.pagination import LinePaginator
-
-log = logging.getLogger(__name__)
-
-KEEP_LOADED = ["bot.cogs.cogs", "bot.cogs.modlog"]
-
-
-class Cogs(Cog):
- """Cog management commands."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
- self.cogs = {}
-
- # Load up the cog names
- log.info("Initializing cog names...")
- for filename in os.listdir("bot/cogs"):
- if filename.endswith(".py") and "_" not in filename:
- if os.path.isfile(f"bot/cogs/{filename}"):
- cog = filename[:-3]
-
- self.cogs[cog] = f"bot.cogs.{cog}"
-
- # Allow reverse lookups by reversing the pairs
- self.cogs.update({v: k for k, v in self.cogs.items()})
-
- @group(name='cogs', aliases=('c',), invoke_without_command=True)
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def cogs_group(self, ctx: Context) -> None:
- """Load, unload, reload, and list active cogs."""
- await ctx.invoke(self.bot.get_command("help"), "cogs")
-
- @cogs_group.command(name='load', aliases=('l',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def load_command(self, ctx: Context, cog: str) -> None:
- """
- Load up an unloaded cog, given the module containing it.
-
- You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the
- entire module directly.
- """
- cog = cog.lower()
-
- embed = Embed()
- embed.colour = Colour.red()
-
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- if cog in self.cogs:
- full_cog = self.cogs[cog]
- elif "." in cog:
- full_cog = cog
- else:
- full_cog = None
- log.warning(f"{ctx.author} requested we load the '{cog}' cog, but that cog doesn't exist.")
- embed.description = f"Unknown cog: {cog}"
-
- if full_cog:
- if full_cog not in self.bot.extensions:
- try:
- self.bot.load_extension(full_cog)
- except ImportError:
- log.exception(f"{ctx.author} requested we load the '{cog}' cog, "
- f"but the cog module {full_cog} could not be found!")
- embed.description = f"Invalid cog: {cog}\n\nCould not find cog module {full_cog}"
- except Exception as e:
- log.exception(f"{ctx.author} requested we load the '{cog}' cog, "
- "but the loading failed")
- embed.description = f"Failed to load cog: {cog}\n\n{e.__class__.__name__}: {e}"
- else:
- log.debug(f"{ctx.author} requested we load the '{cog}' cog. Cog loaded!")
- embed.description = f"Cog loaded: {cog}"
- embed.colour = Colour.green()
- else:
- log.warning(f"{ctx.author} requested we load the '{cog}' cog, but the cog was already loaded!")
- embed.description = f"Cog {cog} is already loaded"
-
- await ctx.send(embed=embed)
-
- @cogs_group.command(name='unload', aliases=('ul',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def unload_command(self, ctx: Context, cog: str) -> None:
- """
- Unload an already-loaded cog, given the module containing it.
-
- You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the
- entire module directly.
- """
- cog = cog.lower()
-
- embed = Embed()
- embed.colour = Colour.red()
-
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- if cog in self.cogs:
- full_cog = self.cogs[cog]
- elif "." in cog:
- full_cog = cog
- else:
- full_cog = None
- log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but that cog doesn't exist.")
- embed.description = f"Unknown cog: {cog}"
-
- if full_cog:
- if full_cog in KEEP_LOADED:
- log.warning(f"{ctx.author} requested we unload `{full_cog}`, that sneaky pete. We said no.")
- embed.description = f"You may not unload `{full_cog}`!"
- elif full_cog in self.bot.extensions:
- try:
- self.bot.unload_extension(full_cog)
- except Exception as e:
- log.exception(f"{ctx.author} requested we unload the '{cog}' cog, "
- "but the unloading failed")
- embed.description = f"Failed to unload cog: {cog}\n\n```{e}```"
- else:
- log.debug(f"{ctx.author} requested we unload the '{cog}' cog. Cog unloaded!")
- embed.description = f"Cog unloaded: {cog}"
- embed.colour = Colour.green()
- else:
- log.warning(f"{ctx.author} requested we unload the '{cog}' cog, but the cog wasn't loaded!")
- embed.description = f"Cog {cog} is not loaded"
-
- await ctx.send(embed=embed)
-
- @cogs_group.command(name='reload', aliases=('r',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def reload_command(self, ctx: Context, cog: str) -> None:
- """
- Reload an unloaded cog, given the module containing it.
-
- You can specify the cog name for any cogs that are placed directly within `!cogs`, or specify the
- entire module directly.
-
- If you specify "*" as the cog, every cog currently loaded will be unloaded, and then every cog present in the
- bot/cogs directory will be loaded.
- """
- cog = cog.lower()
-
- embed = Embed()
- embed.colour = Colour.red()
-
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- if cog == "*":
- full_cog = cog
- elif cog in self.cogs:
- full_cog = self.cogs[cog]
- elif "." in cog:
- full_cog = cog
- else:
- full_cog = None
- log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but that cog doesn't exist.")
- embed.description = f"Unknown cog: {cog}"
-
- if full_cog:
- if full_cog == "*":
- all_cogs = [
- f"bot.cogs.{fn[:-3]}" for fn in os.listdir("bot/cogs")
- if os.path.isfile(f"bot/cogs/{fn}") and fn.endswith(".py") and "_" not in fn
- ]
-
- failed_unloads = {}
- failed_loads = {}
-
- unloaded = 0
- loaded = 0
-
- for loaded_cog in self.bot.extensions.copy().keys():
- try:
- self.bot.unload_extension(loaded_cog)
- except Exception as e:
- failed_unloads[loaded_cog] = f"{e.__class__.__name__}: {e}"
- else:
- unloaded += 1
-
- for unloaded_cog in all_cogs:
- try:
- self.bot.load_extension(unloaded_cog)
- except Exception as e:
- failed_loads[unloaded_cog] = f"{e.__class__.__name__}: {e}"
- else:
- loaded += 1
-
- lines = [
- "**All cogs reloaded**",
- f"**Unloaded**: {unloaded} / **Loaded**: {loaded}"
- ]
-
- if failed_unloads:
- lines.append("\n**Unload failures**")
-
- for cog, error in failed_unloads:
- lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`")
-
- if failed_loads:
- lines.append("\n**Load failures**")
-
- for cog, error in failed_loads.items():
- lines.append(f"{Emojis.status_dnd} **{cog}:** `{error}`")
-
- log.debug(f"{ctx.author} requested we reload all cogs. Here are the results: \n"
- f"{lines}")
-
- await LinePaginator.paginate(lines, ctx, embed, empty=False)
- return
-
- elif full_cog in self.bot.extensions:
- try:
- self.bot.unload_extension(full_cog)
- self.bot.load_extension(full_cog)
- except Exception as e:
- log.exception(f"{ctx.author} requested we reload the '{cog}' cog, "
- "but the unloading failed")
- embed.description = f"Failed to reload cog: {cog}\n\n```{e}```"
- else:
- log.debug(f"{ctx.author} requested we reload the '{cog}' cog. Cog reloaded!")
- embed.description = f"Cog reload: {cog}"
- embed.colour = Colour.green()
- else:
- log.warning(f"{ctx.author} requested we reload the '{cog}' cog, but the cog wasn't loaded!")
- embed.description = f"Cog {cog} is not loaded"
-
- await ctx.send(embed=embed)
-
- @cogs_group.command(name='list', aliases=('all',))
- @with_role(*MODERATION_ROLES, Roles.core_developer)
- async def list_command(self, ctx: Context) -> None:
- """
- Get a list of all cogs, including their loaded status.
-
- Gray indicates that the cog is unloaded. Green indicates that the cog is currently loaded.
- """
- embed = Embed()
- lines = []
- cogs = {}
-
- embed.colour = Colour.blurple()
- embed.set_author(
- name="Python Bot (Cogs)",
- url=URLs.github_bot_repo,
- icon_url=URLs.bot_avatar
- )
-
- for key, _value in self.cogs.items():
- if "." not in key:
- continue
-
- if key in self.bot.extensions:
- cogs[key] = True
- else:
- cogs[key] = False
-
- for key in self.bot.extensions.keys():
- if key not in self.cogs:
- cogs[key] = True
-
- for cog, loaded in sorted(cogs.items(), key=lambda x: x[0]):
- if cog in self.cogs:
- cog = self.cogs[cog]
-
- if loaded:
- status = Emojis.status_online
- else:
- status = Emojis.status_offline
-
- lines.append(f"{status} {cog}")
-
- log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
- await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False)
-
-
-def setup(bot: Bot) -> None:
- """Cogs cog load."""
- bot.add_cog(Cogs(bot))
- log.info("Cog loaded: Cogs")
diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py
index 048d8a683..bedd70c86 100644
--- a/bot/cogs/defcon.py
+++ b/bot/cogs/defcon.py
@@ -1,10 +1,14 @@
+from __future__ import annotations
+
import logging
+from collections import namedtuple
from datetime import datetime, timedelta
+from enum import Enum
from discord import Colour, Embed, Member
from discord.ext.commands import Bot, Cog, Context, group
-from bot.cogs.modlog import ModLog
+from bot.cogs.moderation import ModLog
from bot.constants import Channels, Colours, Emojis, Event, Icons, Roles
from bot.decorators import with_role
@@ -24,6 +28,16 @@ will be resolved soon. In the meantime, please feel free to peruse the resources
BASE_CHANNEL_TOPIC = "Python Discord Defense Mechanism"
+class Action(Enum):
+ """Defcon Action."""
+
+ ActionInfo = namedtuple('LogInfoDetails', ['icon', 'color', 'template'])
+
+ ENABLED = ActionInfo(Icons.defcon_enabled, Colours.soft_green, "**Days:** {days}\n\n")
+ DISABLED = ActionInfo(Icons.defcon_disabled, Colours.soft_red, "")
+ UPDATED = ActionInfo(Icons.defcon_updated, Colour.blurple(), "**Days:** {days}\n\n")
+
+
class Defcon(Cog):
"""Time-sensitive server defense mechanisms."""
@@ -35,15 +49,18 @@ class Defcon(Cog):
self.channel = None
self.days = timedelta(days=0)
+ self.bot.loop.create_task(self.sync_settings())
+
@property
def mod_log(self) -> ModLog:
"""Get currently loaded ModLog cog instance."""
return self.bot.get_cog("ModLog")
- @Cog.listener()
- async def on_ready(self) -> None:
+ async def sync_settings(self) -> None:
"""On cog load, try to synchronize DEFCON settings to the API."""
+ await self.bot.wait_until_ready()
self.channel = await self.bot.fetch_channel(Channels.defcon)
+
try:
response = await self.bot.api_client.get('bot/bot-settings/defcon')
data = response['data']
@@ -88,8 +105,7 @@ class Defcon(Cog):
await member.kick(reason="DEFCON active, user is too new")
message = (
- f"{member.name}#{member.discriminator} (`{member.id}`) "
- f"was denied entry because their account is too new."
+ f"{member} (`{member.id}`) was denied entry because their account is too new."
)
if not message_sent:
@@ -106,39 +122,39 @@ class Defcon(Cog):
"""Check the DEFCON status or run a subcommand."""
await ctx.invoke(self.bot.get_command("help"), "defcon")
- @defcon_group.command(name='enable', aliases=('on', 'e'))
- @with_role(Roles.admin, Roles.owner)
- async def enable_command(self, ctx: Context) -> None:
- """
- Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!
-
- Currently, this just adds an account age requirement. Use !defcon days <int> to set how old an account must be,
- in days.
- """
- self.enabled = True
-
+ async def _defcon_action(self, ctx: Context, days: int, action: Action) -> None:
+ """Providing a structured way to do an defcon action."""
+ error = None
try:
await self.bot.api_client.put(
'bot/bot-settings/defcon',
json={
'name': 'defcon',
'data': {
- 'enabled': True,
# TODO: retrieve old days count
- 'days': 0
+ 'days': days,
+ 'enabled': action is not Action.DISABLED,
}
}
)
-
- except Exception as e:
+ except Exception as err:
log.exception("Unable to update DEFCON settings.")
- await ctx.send(self.build_defcon_msg("enabled", e))
- await self.send_defcon_log("enabled", ctx.author, e)
+ error = err
+ finally:
+ await ctx.send(self.build_defcon_msg(action, error))
+ await self.send_defcon_log(action, ctx.author, error)
- else:
- await ctx.send(self.build_defcon_msg("enabled"))
- await self.send_defcon_log("enabled", ctx.author)
+ @defcon_group.command(name='enable', aliases=('on', 'e'))
+ @with_role(Roles.admin, Roles.owner)
+ async def enable_command(self, ctx: Context) -> None:
+ """
+ Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!
+ Currently, this just adds an account age requirement. Use !defcon days <int> to set how old an account must be,
+ in days.
+ """
+ self.enabled = True
+ await self._defcon_action(ctx, days=0, action=Action.ENABLED)
await self.update_channel_topic()
@defcon_group.command(name='disable', aliases=('off', 'd'))
@@ -146,26 +162,7 @@ class Defcon(Cog):
async def disable_command(self, ctx: Context) -> None:
"""Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!"""
self.enabled = False
-
- try:
- await self.bot.api_client.put(
- 'bot/bot-settings/defcon',
- json={
- 'data': {
- 'days': 0,
- 'enabled': False
- },
- 'name': 'defcon'
- }
- )
- except Exception as e:
- log.exception("Unable to update DEFCON settings.")
- await ctx.send(self.build_defcon_msg("disabled", e))
- await self.send_defcon_log("disabled", ctx.author, e)
- else:
- await ctx.send(self.build_defcon_msg("disabled"))
- await self.send_defcon_log("disabled", ctx.author)
-
+ await self._defcon_action(ctx, days=0, action=Action.DISABLED)
await self.update_channel_topic()
@defcon_group.command(name='status', aliases=('s',))
@@ -185,30 +182,8 @@ class Defcon(Cog):
async def days_command(self, ctx: Context, days: int) -> None:
"""Set how old an account must be to join the server, in days, with DEFCON mode enabled."""
self.days = timedelta(days=days)
-
- try:
- await self.bot.api_client.put(
- 'bot/bot-settings/defcon',
- json={
- 'data': {
- 'days': days,
- 'enabled': True
- },
- 'name': 'defcon'
- }
- )
- except Exception as e:
- log.exception("Unable to update DEFCON settings.")
- await ctx.send(self.build_defcon_msg("updated", e))
- await self.send_defcon_log("updated", ctx.author, e)
- else:
- await ctx.send(self.build_defcon_msg("updated"))
- await self.send_defcon_log("updated", ctx.author)
-
- # Enable DEFCON if it's not already
- if not self.enabled:
- self.enabled = True
-
+ self.enabled = True
+ await self._defcon_action(ctx, days=days, action=Action.UPDATED)
await self.update_channel_topic()
async def update_channel_topic(self) -> None:
@@ -222,20 +197,16 @@ class Defcon(Cog):
self.mod_log.ignore(Event.guild_channel_update, Channels.defcon)
await self.channel.edit(topic=new_topic)
- def build_defcon_msg(self, change: str, e: Exception = None) -> str:
- """
- Build in-channel response string for DEFCON action.
-
- `change` string may be one of the following: ('enabled', 'disabled', 'updated')
- """
- if change.lower() == "enabled":
+ def build_defcon_msg(self, action: Action, e: Exception = None) -> str:
+ """Build in-channel response string for DEFCON action."""
+ if action is Action.ENABLED:
msg = f"{Emojis.defcon_enabled} DEFCON enabled.\n\n"
- elif change.lower() == "disabled":
+ elif action is Action.DISABLED:
msg = f"{Emojis.defcon_disabled} DEFCON disabled.\n\n"
- elif change.lower() == "updated":
+ elif action is Action.UPDATED:
msg = (
- f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days} "
- "days old to join the server.\n\n"
+ f"{Emojis.defcon_updated} DEFCON days updated; accounts must be {self.days.days} "
+ f"day{'s' if self.days.days > 1 else ''} old to join the server.\n\n"
)
if e:
@@ -246,28 +217,14 @@ class Defcon(Cog):
return msg
- async def send_defcon_log(self, change: str, actor: Member, e: Exception = None) -> None:
- """
- Send log message for DEFCON action.
-
- `change` string may be one of the following: ('enabled', 'disabled', 'updated')
- """
- log_msg = f"**Staffer:** {actor.name}#{actor.discriminator} (`{actor.id}`)\n"
-
- if change.lower() == "enabled":
- icon = Icons.defcon_enabled
- color = Colours.soft_green
- status_msg = "DEFCON enabled"
- log_msg += f"**Days:** {self.days.days}\n\n"
- elif change.lower() == "disabled":
- icon = Icons.defcon_disabled
- color = Colours.soft_red
- status_msg = "DEFCON enabled"
- elif change.lower() == "updated":
- icon = Icons.defcon_updated
- color = Colour.blurple()
- status_msg = "DEFCON updated"
- log_msg += f"**Days:** {self.days.days}\n\n"
+ async def send_defcon_log(self, action: Action, actor: Member, e: Exception = None) -> None:
+ """Send log message for DEFCON action."""
+ info = action.value
+ log_msg: str = (
+ f"**Staffer:** {actor.mention} {actor} (`{actor.id}`)\n"
+ f"{info.template.format(days=self.days.days)}"
+ )
+ status_msg = f"DEFCON {action.name.lower()}"
if e:
log_msg += (
@@ -275,7 +232,7 @@ class Defcon(Cog):
f"```py\n{e}\n```"
)
- await self.mod_log.send_log_message(icon, color, status_msg, log_msg)
+ await self.mod_log.send_log_message(info.icon, info.color, status_msg, log_msg)
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py
index c9e6b3b91..65cabe46f 100644
--- a/bot/cogs/doc.py
+++ b/bot/cogs/doc.py
@@ -126,9 +126,11 @@ class Doc(commands.Cog):
self.bot = bot
self.inventories = {}
- @commands.Cog.listener()
- async def on_ready(self) -> None:
- """Refresh documentation inventory."""
+ self.bot.loop.create_task(self.init_refresh_inventory())
+
+ async def init_refresh_inventory(self) -> None:
+ """Refresh documentation inventory on cog initialization."""
+ await self.bot.wait_until_ready()
await self.refresh_inventory()
async def update_single(
@@ -207,6 +209,9 @@ class Doc(commands.Cog):
symbol_heading = soup.find(id=symbol_id)
signature_buffer = []
+ if symbol_heading is None:
+ return None
+
# Traverse the tags of the signature header and ignore any
# unwanted symbols from it. Add all of it to a temporary buffer.
for tag in symbol_heading.strings:
@@ -331,8 +336,7 @@ class Doc(commands.Cog):
await self.bot.api_client.post('bot/documentation-links', json=body)
log.info(
- f"User @{ctx.author.name}#{ctx.author.discriminator} ({ctx.author.id}) "
- "added a new documentation package:\n"
+ f"User @{ctx.author} ({ctx.author.id}) added a new documentation package:\n"
f"Package name: {package_name}\n"
f"Base url: {base_url}\n"
f"Inventory URL: {inventory_url}"
diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py
new file mode 100644
index 000000000..bb66e0b8e
--- /dev/null
+++ b/bot/cogs/extensions.py
@@ -0,0 +1,236 @@
+import functools
+import logging
+import typing as t
+from enum import Enum
+from pkgutil import iter_modules
+
+from discord import Colour, Embed
+from discord.ext import commands
+from discord.ext.commands import Bot, Context, group
+
+from bot.constants import Emojis, MODERATION_ROLES, Roles, URLs
+from bot.pagination import LinePaginator
+from bot.utils.checks import with_role_check
+
+log = logging.getLogger(__name__)
+
+UNLOAD_BLACKLIST = {"bot.cogs.extensions", "bot.cogs.modlog"}
+EXTENSIONS = frozenset(
+ ext.name
+ for ext in iter_modules(("bot/cogs",), "bot.cogs.")
+ if ext.name[-1] != "_"
+)
+
+
+class Action(Enum):
+ """Represents an action to perform on an extension."""
+
+ # Need to be partial otherwise they are considered to be function definitions.
+ LOAD = functools.partial(Bot.load_extension)
+ UNLOAD = functools.partial(Bot.unload_extension)
+ RELOAD = functools.partial(Bot.reload_extension)
+
+
+class Extension(commands.Converter):
+ """
+ Fully qualify the name of an extension and ensure it exists.
+
+ The * and ** values bypass this when used with the reload command.
+ """
+
+ async def convert(self, ctx: Context, argument: str) -> str:
+ """Fully qualify the name of an extension and ensure it exists."""
+ # Special values to reload all extensions
+ if argument == "*" or argument == "**":
+ return argument
+
+ argument = argument.lower()
+
+ if "." not in argument:
+ argument = f"bot.cogs.{argument}"
+
+ if argument in EXTENSIONS:
+ return argument
+ else:
+ raise commands.BadArgument(f":x: Could not find the extension `{argument}`.")
+
+
+class Extensions(commands.Cog):
+ """Extension management commands."""
+
+ def __init__(self, bot: Bot):
+ self.bot = bot
+
+ @group(name="extensions", aliases=("ext", "exts", "c", "cogs"), invoke_without_command=True)
+ async def extensions_group(self, ctx: Context) -> None:
+ """Load, unload, reload, and list loaded extensions."""
+ await ctx.invoke(self.bot.get_command("help"), "extensions")
+
+ @extensions_group.command(name="load", aliases=("l",))
+ async def load_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Load extensions given their fully qualified or unqualified names.
+
+ If '\*' or '\*\*' is given as the name, all unloaded extensions will be loaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.invoke(self.bot.get_command("help"), "extensions load")
+ return
+
+ if "*" in extensions or "**" in extensions:
+ extensions = set(EXTENSIONS) - set(self.bot.extensions.keys())
+
+ msg = self.batch_manage(Action.LOAD, *extensions)
+ await ctx.send(msg)
+
+ @extensions_group.command(name="unload", aliases=("ul",))
+ async def unload_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Unload currently loaded extensions given their fully qualified or unqualified names.
+
+ If '\*' or '\*\*' is given as the name, all loaded extensions will be unloaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.invoke(self.bot.get_command("help"), "extensions unload")
+ return
+
+ blacklisted = "\n".join(UNLOAD_BLACKLIST & set(extensions))
+
+ if blacklisted:
+ msg = f":x: The following extension(s) may not be unloaded:```{blacklisted}```"
+ else:
+ if "*" in extensions or "**" in extensions:
+ extensions = set(self.bot.extensions.keys()) - UNLOAD_BLACKLIST
+
+ msg = self.batch_manage(Action.UNLOAD, *extensions)
+
+ await ctx.send(msg)
+
+ @extensions_group.command(name="reload", aliases=("r",))
+ async def reload_command(self, ctx: Context, *extensions: Extension) -> None:
+ """
+ Reload extensions given their fully qualified or unqualified names.
+
+ If an extension fails to be reloaded, it will be rolled-back to the prior working state.
+
+ If '\*' is given as the name, all currently loaded extensions will be reloaded.
+ If '\*\*' is given as the name, all extensions, including unloaded ones, will be reloaded.
+ """ # noqa: W605
+ if not extensions:
+ await ctx.invoke(self.bot.get_command("help"), "extensions reload")
+ return
+
+ if "**" in extensions:
+ extensions = EXTENSIONS
+ elif "*" in extensions:
+ extensions = set(self.bot.extensions.keys()) | set(extensions)
+ extensions.remove("*")
+
+ msg = self.batch_manage(Action.RELOAD, *extensions)
+
+ await ctx.send(msg)
+
+ @extensions_group.command(name="list", aliases=("all",))
+ async def list_command(self, ctx: Context) -> None:
+ """
+ Get a list of all extensions, including their loaded status.
+
+ Grey indicates that the extension is unloaded.
+ Green indicates that the extension is currently loaded.
+ """
+ embed = Embed()
+ lines = []
+
+ embed.colour = Colour.blurple()
+ embed.set_author(
+ name="Extensions List",
+ url=URLs.github_bot_repo,
+ icon_url=URLs.bot_avatar
+ )
+
+ for ext in sorted(list(EXTENSIONS)):
+ if ext in self.bot.extensions:
+ status = Emojis.status_online
+ else:
+ status = Emojis.status_offline
+
+ ext = ext.rsplit(".", 1)[1]
+ lines.append(f"{status} {ext}")
+
+ log.debug(f"{ctx.author} requested a list of all cogs. Returning a paginated list.")
+ await LinePaginator.paginate(lines, ctx, embed, max_size=300, empty=False)
+
+ def batch_manage(self, action: Action, *extensions: str) -> str:
+ """
+ Apply an action to multiple extensions and return a message with the results.
+
+ If only one extension is given, it is deferred to `manage()`.
+ """
+ if len(extensions) == 1:
+ msg, _ = self.manage(action, extensions[0])
+ return msg
+
+ verb = action.name.lower()
+ failures = {}
+
+ for extension in extensions:
+ _, error = self.manage(action, extension)
+ if error:
+ failures[extension] = error
+
+ emoji = ":x:" if failures else ":ok_hand:"
+ msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} extensions {verb}ed."
+
+ if failures:
+ failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items())
+ msg += f"\nFailures:```{failures}```"
+
+ log.debug(f"Batch {verb}ed extensions.")
+
+ return msg
+
+ def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:
+ """Apply an action to an extension and return the status message and any error message."""
+ verb = action.name.lower()
+ error_msg = None
+
+ try:
+ action.value(self.bot, ext)
+ except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):
+ if action is Action.RELOAD:
+ # When reloading, just load the extension if it was not loaded.
+ return self.manage(Action.LOAD, ext)
+
+ msg = f":x: Extension `{ext}` is already {verb}ed."
+ log.debug(msg[4:])
+ except Exception as e:
+ if hasattr(e, "original"):
+ e = e.original
+
+ log.exception(f"Extension '{ext}' failed to {verb}.")
+
+ error_msg = f"{e.__class__.__name__}: {e}"
+ msg = f":x: Failed to {verb} extension `{ext}`:\n```{error_msg}```"
+ else:
+ msg = f":ok_hand: Extension successfully {verb}ed: `{ext}`."
+ log.debug(msg[10:])
+
+ return msg, error_msg
+
+ # This cannot be static (must have a __func__ attribute).
+ def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators and core developers to invoke the commands in this cog."""
+ return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer)
+
+ # This cannot be static (must have a __func__ attribute).
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Handle BadArgument errors locally to prevent the help command from showing."""
+ if isinstance(error, commands.BadArgument):
+ await ctx.send(str(error))
+ error.handled = True
+
+
+def setup(bot: Bot) -> None:
+ """Load the Extensions cog."""
+ bot.add_cog(Extensions(bot))
+ log.info("Cog loaded: Extensions")
diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py
index bd8c6ed67..4195783f1 100644
--- a/bot/cogs/filtering.py
+++ b/bot/cogs/filtering.py
@@ -7,9 +7,9 @@ from dateutil.relativedelta import relativedelta
from discord import Colour, DMChannel, Member, Message, TextChannel
from discord.ext.commands import Bot, Cog
-from bot.cogs.modlog import ModLog
+from bot.cogs.moderation import ModLog
from bot.constants import (
- Channels, Colours, DEBUG_MODE,
+ Channels, Colours,
Filter, Icons, URLs
)
@@ -63,7 +63,7 @@ class Filtering(Cog):
"content_only": True,
"user_notification": Filter.notify_user_invites,
"notification_msg": (
- f"Per Rule 10, your invite link has been removed. {_staff_mistake_str}\n\n"
+ f"Per Rule 6, your invite link has been removed. {_staff_mistake_str}\n\n"
r"Our server rules can be found here: <https://pythondiscord.com/pages/rules>"
)
},
@@ -136,10 +136,6 @@ class Filtering(Cog):
and not msg.author.bot # Author not a bot
)
- # If we're running the bot locally, ignore role whitelist and only listen to #dev-test
- if DEBUG_MODE:
- filter_message = not msg.author.bot and msg.channel.id == Channels.devtest
-
# If none of the above, we can start filtering.
if filter_message:
for filter_name, _filter in self.filters.items():
@@ -154,11 +150,11 @@ class Filtering(Cog):
# Does the filter only need the message content or the full message?
if _filter["content_only"]:
- triggered = await _filter["function"](msg.content)
+ match = await _filter["function"](msg.content)
else:
- triggered = await _filter["function"](msg)
+ match = await _filter["function"](msg)
- if triggered:
+ if match:
# If this is a filter (not a watchlist), we should delete the message.
if _filter["type"] == "filter":
try:
@@ -184,12 +180,23 @@ class Filtering(Cog):
else:
channel_str = f"in {msg.channel.mention}"
+ # Word and match stats for watch_words and watch_tokens
+ if filter_name in ("watch_words", "watch_tokens"):
+ surroundings = match.string[max(match.start() - 10, 0): match.end() + 10]
+ message_content = (
+ f"**Match:** '{match[0]}'\n"
+ f"**Location:** '...{surroundings}...'\n"
+ f"\n**Original Message:**\n{msg.content}"
+ )
+ else: # Use content of discord Message
+ message_content = msg.content
+
message = (
f"The {filter_name} {_filter['type']} was triggered "
- f"by **{msg.author.name}#{msg.author.discriminator}** "
+ f"by **{msg.author}** "
f"(`{msg.author.id}`) {channel_str} with [the "
f"following message]({msg.jump_url}):\n\n"
- f"{msg.content}"
+ f"{message_content}"
)
log.debug(message)
@@ -199,7 +206,7 @@ class Filtering(Cog):
if filter_name == "filter_invites":
additional_embeds = []
- for invite, data in triggered.items():
+ for invite, data in match.items():
embed = discord.Embed(description=(
f"**Members:**\n{data['members']}\n"
f"**Active:**\n{data['active']}"
@@ -230,31 +237,33 @@ class Filtering(Cog):
break # We don't want multiple filters to trigger
@staticmethod
- async def _has_watchlist_words(text: str) -> bool:
+ async def _has_watchlist_words(text: str) -> Union[bool, re.Match]:
"""
Returns True if the text contains one of the regular expressions from the word_watchlist in our filter config.
Only matches words with boundaries before and after the expression.
"""
for regex_pattern in WORD_WATCHLIST_PATTERNS:
- if regex_pattern.search(text):
- return True
+ match = regex_pattern.search(text)
+ if match:
+ return match # match objects always have a boolean value of True
return False
@staticmethod
- async def _has_watchlist_tokens(text: str) -> bool:
+ async def _has_watchlist_tokens(text: str) -> Union[bool, re.Match]:
"""
Returns True if the text contains one of the regular expressions from the token_watchlist in our filter config.
This will match the expression even if it does not have boundaries before and after.
"""
for regex_pattern in TOKEN_WATCHLIST_PATTERNS:
- if regex_pattern.search(text):
+ match = regex_pattern.search(text)
+ if match:
# Make sure it's not a URL
if not URL_RE.search(text):
- return True
+ return match # match objects always have a boolean value of True
return False
diff --git a/bot/cogs/free.py b/bot/cogs/free.py
index 269c5c1b9..82285656b 100644
--- a/bot/cogs/free.py
+++ b/bot/cogs/free.py
@@ -72,30 +72,27 @@ class Free(Cog):
# Display all potentially inactive channels
# in descending order of inactivity
if free_channels:
- embed.description += "**The following channel{0} look{1} free:**\n\n**".format(
- 's' if len(free_channels) > 1 else '',
- '' if len(free_channels) > 1 else 's'
- )
-
# Sort channels in descending order by seconds
# Get position in list, inactivity, and channel object
# For each channel, add to embed.description
sorted_channels = sorted(free_channels, key=itemgetter(0), reverse=True)
- for i, (inactive, channel) in enumerate(sorted_channels, 1):
+
+ for (inactive, channel) in sorted_channels[:3]:
minutes, seconds = divmod(inactive, 60)
if minutes > 59:
hours, minutes = divmod(minutes, 60)
- embed.description += f"{i}. {channel.mention} inactive for {hours}h{minutes}m{seconds}s\n\n"
+ embed.description += f"{channel.mention} **{hours}h {minutes}m {seconds}s** inactive\n"
else:
- embed.description += f"{i}. {channel.mention} inactive for {minutes}m{seconds}s\n\n"
+ embed.description += f"{channel.mention} **{minutes}m {seconds}s** inactive\n"
- embed.description += ("**\nThese channels aren't guaranteed to be free, "
- "so use your best judgement and check for yourself.")
+ embed.set_footer(text="Please confirm these channels are free before posting")
else:
- embed.description = ("**Doesn't look like any channels are available right now. "
- "You're welcome to check for yourself to be sure. "
- "If all channels are truly busy, please be patient "
- "as one will likely be available soon.**")
+ embed.description = (
+ "Doesn't look like any channels are available right now. "
+ "You're welcome to check for yourself to be sure. "
+ "If all channels are truly busy, please be patient "
+ "as one will likely be available soon."
+ )
await ctx.send(embed=embed)
diff --git a/bot/cogs/help.py b/bot/cogs/help.py
index 37d12b2d5..9607dbd8d 100644
--- a/bot/cogs/help.py
+++ b/bot/cogs/help.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
import itertools
from collections import namedtuple
from contextlib import suppress
@@ -61,6 +60,12 @@ class HelpSession:
The message object that's showing the help contents.
* destination: `discord.abc.Messageable`
Where the help message is to be sent to.
+
+ Cogs can be grouped into custom categories. All cogs with the same category will be displayed
+ under a single category name in the help output. Custom categories are defined inside the cogs
+ as a class attribute named `category`. A description can also be specified with the attribute
+ `category_description`. If a description is not found in at least one cog, the default will be
+ the regular description (class docstring) of the first cog found in the category.
"""
def __init__(
@@ -107,12 +112,31 @@ class HelpSession:
if command:
return command
- cog = self._bot.cogs.get(query)
- if cog:
+ # Find all cog categories that match.
+ cog_matches = []
+ description = None
+ for cog in self._bot.cogs.values():
+ if hasattr(cog, "category") and cog.category == query:
+ cog_matches.append(cog)
+ if hasattr(cog, "category_description"):
+ description = cog.category_description
+
+ # Try to search by cog name if no categories match.
+ if not cog_matches:
+ cog = self._bot.cogs.get(query)
+
+ # Don't consider it a match if the cog has a category.
+ if cog and not hasattr(cog, "category"):
+ cog_matches = [cog]
+
+ if cog_matches:
+ cog = cog_matches[0]
+ cmds = (cog.get_commands() for cog in cog_matches) # Commands of all cogs
+
return Cog(
- name=cog.qualified_name,
- description=inspect.getdoc(cog),
- commands=[c for c in self._bot.commands if c.cog is cog]
+ name=cog.category if hasattr(cog, "category") else cog.qualified_name,
+ description=description or cog.description,
+ commands=tuple(itertools.chain.from_iterable(cmds)) # Flatten the list
)
self._handle_not_found(query)
@@ -207,8 +231,16 @@ class HelpSession:
A zero width space is used as a prefix for results with no cogs to force them last in ordering.
"""
- cog = cmd.cog_name
- return f'**{cog}**' if cog else f'**\u200bNo Category:**'
+ if cmd.cog:
+ try:
+ if cmd.cog.category:
+ return f'**{cmd.cog.category}**'
+ except AttributeError:
+ pass
+
+ return f'**{cmd.cog_name}**'
+ else:
+ return "**\u200bNo Category:**"
def _get_command_params(self, cmd: Command) -> str:
"""
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index 60aec6219..530453600 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -1,12 +1,20 @@
+import colorsys
import logging
+import pprint
import textwrap
-
-from discord import CategoryChannel, Colour, Embed, Member, TextChannel, VoiceChannel
-from discord.ext.commands import Bot, Cog, Context, command
-
-from bot.constants import Channels, Emojis, MODERATION_ROLES, STAFF_ROLES
-from bot.decorators import InChannelCheckFailure, with_role
-from bot.utils.checks import with_role_check
+import typing
+from collections import defaultdict
+from typing import Any, Mapping, Optional
+
+import discord
+from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils
+from discord.ext import commands
+from discord.ext.commands import Bot, BucketType, Cog, Context, command, group
+from discord.utils import escape_markdown
+
+from bot import constants
+from bot.decorators import InChannelCheckFailure, in_channel, with_role
+from bot.utils.checks import cooldown_with_role_bypass, with_role_check
from bot.utils.time import time_since
log = logging.getLogger(__name__)
@@ -18,7 +26,7 @@ class Information(Cog):
def __init__(self, bot: Bot):
self.bot = bot
- @with_role(*MODERATION_ROLES)
+ @with_role(*constants.MODERATION_ROLES)
@command(name="roles")
async def roles_info(self, ctx: Context) -> None:
"""Returns a list of all roles and their corresponding IDs."""
@@ -42,6 +50,52 @@ class Information(Cog):
await ctx.send(embed=embed)
+ @with_role(*constants.MODERATION_ROLES)
+ @command(name="role")
+ async def role_info(self, ctx: Context, *roles: typing.Union[Role, str]) -> None:
+ """
+ Return information on a role or list of roles.
+
+ To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks.
+ """
+ parsed_roles = []
+
+ for role_name in roles:
+ if isinstance(role_name, Role):
+ # Role conversion has already succeeded
+ parsed_roles.append(role_name)
+ continue
+
+ role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles)
+
+ if not role:
+ await ctx.send(f":x: Could not convert `{role_name}` to a role")
+ continue
+
+ parsed_roles.append(role)
+
+ for role in parsed_roles:
+ embed = Embed(
+ title=f"{role.name} info",
+ colour=role.colour,
+ )
+
+ embed.add_field(name="ID", value=role.id, inline=True)
+
+ embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True)
+
+ h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb())
+
+ embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True)
+
+ embed.add_field(name="Member count", value=len(role.members), inline=True)
+
+ embed.add_field(name="Position", value=role.position)
+
+ embed.add_field(name="Permission code", value=role.permissions.value, inline=True)
+
+ await ctx.send(embed=embed)
+
@command(name="server", aliases=["server_info", "guild", "guild_info"])
async def server_info(self, ctx: Context) -> None:
"""Returns an embed full of server information."""
@@ -96,10 +150,10 @@ class Information(Cog):
Channel categories: {category_channels}
**Members**
- {Emojis.status_online} {online}
- {Emojis.status_idle} {idle}
- {Emojis.status_dnd} {dnd}
- {Emojis.status_offline} {offline}
+ {constants.Emojis.status_online} {online}
+ {constants.Emojis.status_idle} {idle}
+ {constants.Emojis.status_dnd} {dnd}
+ {constants.Emojis.status_offline} {offline}
""")
)
@@ -108,78 +162,232 @@ class Information(Cog):
await ctx.send(embed=embed)
@command(name="user", aliases=["user_info", "member", "member_info"])
- async def user_info(self, ctx: Context, user: Member = None, hidden: bool = False) -> None:
+ async def user_info(self, ctx: Context, user: Member = None) -> None:
"""Returns info about a user."""
if user is None:
user = ctx.author
# Do a role check if this is being executed on someone other than the caller
- if user != ctx.author and not with_role_check(ctx, *MODERATION_ROLES):
+ if user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES):
await ctx.send("You may not use this command on users other than yourself.")
return
- # Non-moderators may only do this in #bot-commands and can't see hidden infractions.
- if not with_role_check(ctx, *STAFF_ROLES):
- if not ctx.channel.id == Channels.bot:
- raise InChannelCheckFailure(Channels.bot)
- # Hide hidden infractions for users without a moderation role
- hidden = False
+ # Non-staff may only do this in #bot-commands
+ if not with_role_check(ctx, *constants.STAFF_ROLES):
+ if not ctx.channel.id == constants.Channels.bot:
+ raise InChannelCheckFailure(constants.Channels.bot)
+
+ embed = await self.create_user_embed(ctx, user)
- # User information
+ await ctx.send(embed=embed)
+
+ async def create_user_embed(self, ctx: Context, user: Member) -> Embed:
+ """Creates an embed containing information on the `user`."""
created = time_since(user.created_at, max_units=3)
+ # Custom status
+ custom_status = ''
+ for activity in user.activities:
+ if activity.name == 'Custom Status':
+ state = escape_markdown(activity.state)
+ custom_status = f'Status: {state}\n'
+
name = str(user)
if user.nick:
name = f"{user.nick} ({name})"
- # Member information
joined = time_since(user.joined_at, precision="days")
-
- # You're welcome, Volcyyyyyyyyyyyyyyyy
roles = ", ".join(role.mention for role in user.roles if role.name != "@everyone")
- # Infractions
+ description = [
+ textwrap.dedent(f"""
+ **User Information**
+ Created: {created}
+ Profile: {user.mention}
+ ID: {user.id}
+ {custom_status}
+ **Member Information**
+ Joined: {joined}
+ Roles: {roles or None}
+ """).strip()
+ ]
+
+ # Show more verbose output in moderation channels for infractions and nominations
+ if ctx.channel.id in constants.MODERATION_CHANNELS:
+ description.append(await self.expanded_user_infraction_counts(user))
+ description.append(await self.user_nomination_counts(user))
+ else:
+ description.append(await self.basic_user_infraction_counts(user))
+
+ # Let's build the embed now
+ embed = Embed(
+ title=name,
+ description="\n\n".join(description)
+ )
+
+ embed.set_thumbnail(url=user.avatar_url_as(format="png"))
+ embed.colour = user.top_role.colour if roles else Colour.blurple()
+
+ return embed
+
+ async def basic_user_infraction_counts(self, member: Member) -> str:
+ """Gets the total and active infraction counts for the given `member`."""
infractions = await self.bot.api_client.get(
'bot/infractions',
params={
- 'hidden': str(hidden),
- 'user__id': str(user.id)
+ 'hidden': 'False',
+ 'user__id': str(member.id)
}
)
- infr_total = 0
- infr_active = 0
+ total_infractions = len(infractions)
+ active_infractions = sum(infraction['active'] for infraction in infractions)
- # At least it's readable.
- for infr in infractions:
- if infr["active"]:
- infr_active += 1
+ infraction_output = f"**Infractions**\nTotal: {total_infractions}\nActive: {active_infractions}"
- infr_total += 1
+ return infraction_output
- # Let's build the embed now
- embed = Embed(
- title=name,
- description=textwrap.dedent(f"""
- **User Information**
- Created: {created}
- Profile: {user.mention}
- ID: {user.id}
+ async def expanded_user_infraction_counts(self, member: Member) -> str:
+ """
+ Gets expanded infraction counts for the given `member`.
- **Member Information**
- Joined: {joined}
- Roles: {roles or None}
+ The counts will be split by infraction type and the number of active infractions for each type will indicated
+ in the output as well.
+ """
+ infractions = await self.bot.api_client.get(
+ 'bot/infractions',
+ params={
+ 'user__id': str(member.id)
+ }
+ )
- **Infractions**
- Total: {infr_total}
- Active: {infr_active}
- """)
+ infraction_output = ["**Infractions**"]
+ if not infractions:
+ infraction_output.append("This user has never received an infraction.")
+ else:
+ # Count infractions split by `type` and `active` status for this user
+ infraction_types = set()
+ infraction_counter = defaultdict(int)
+ for infraction in infractions:
+ infraction_type = infraction["type"]
+ infraction_active = 'active' if infraction["active"] else 'inactive'
+
+ infraction_types.add(infraction_type)
+ infraction_counter[f"{infraction_active} {infraction_type}"] += 1
+
+ # Format the output of the infraction counts
+ for infraction_type in sorted(infraction_types):
+ active_count = infraction_counter[f"active {infraction_type}"]
+ total_count = active_count + infraction_counter[f"inactive {infraction_type}"]
+
+ line = f"{infraction_type.capitalize()}s: {total_count}"
+ if active_count:
+ line += f" ({active_count} active)"
+
+ infraction_output.append(line)
+
+ return "\n".join(infraction_output)
+
+ async def user_nomination_counts(self, member: Member) -> str:
+ """Gets the active and historical nomination counts for the given `member`."""
+ nominations = await self.bot.api_client.get(
+ 'bot/nominations',
+ params={
+ 'user__id': str(member.id)
+ }
)
- embed.set_thumbnail(url=user.avatar_url_as(format="png"))
- embed.colour = user.top_role.colour if roles else Colour.blurple()
+ output = ["**Nominations**"]
- await ctx.send(embed=embed)
+ if not nominations:
+ output.append("This user has never been nominated.")
+ else:
+ count = len(nominations)
+ is_currently_nominated = any(nomination["active"] for nomination in nominations)
+ nomination_noun = "nomination" if count == 1 else "nominations"
+
+ if is_currently_nominated:
+ output.append(f"This user is **currently** nominated ({count} {nomination_noun} in total).")
+ else:
+ output.append(f"This user has {count} historical {nomination_noun}, but is currently not nominated.")
+
+ return "\n".join(output)
+
+ def format_fields(self, mapping: Mapping[str, Any], field_width: Optional[int] = None) -> str:
+ """Format a mapping to be readable to a human."""
+ # sorting is technically superfluous but nice if you want to look for a specific field
+ fields = sorted(mapping.items(), key=lambda item: item[0])
+
+ if field_width is None:
+ field_width = len(max(mapping.keys(), key=len))
+
+ out = ''
+
+ for key, val in fields:
+ if isinstance(val, dict):
+ # if we have dicts inside dicts we want to apply the same treatment to the inner dictionaries
+ inner_width = int(field_width * 1.6)
+ val = '\n' + self.format_fields(val, field_width=inner_width)
+
+ elif isinstance(val, str):
+ # split up text since it might be long
+ text = textwrap.fill(val, width=100, replace_whitespace=False)
+
+ # indent it, I guess you could do this with `wrap` and `join` but this is nicer
+ val = textwrap.indent(text, ' ' * (field_width + len(': ')))
+
+ # the first line is already indented so we `str.lstrip` it
+ val = val.lstrip()
+
+ if key == 'color':
+ # makes the base 10 representation of a hex number readable to humans
+ val = hex(val)
+
+ out += '{0:>{width}}: {1}\n'.format(key, val, width=field_width)
+
+ # remove trailing whitespace
+ return out.rstrip()
+
+ @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES)
+ @group(invoke_without_command=True)
+ @in_channel(constants.Channels.bot, bypass_roles=constants.STAFF_ROLES)
+ async def raw(self, ctx: Context, *, message: discord.Message, json: bool = False) -> None:
+ """Shows information about the raw API response."""
+ # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling
+ # doing this extra request is also much easier than trying to convert everything back into a dictionary again
+ raw_data = await ctx.bot.http.get_message(message.channel.id, message.id)
+
+ paginator = commands.Paginator()
+
+ def add_content(title: str, content: str) -> None:
+ paginator.add_line(f'== {title} ==\n')
+ # replace backticks as it breaks out of code blocks. Spaces seemed to be the most reasonable solution.
+ # we hope it's not close to 2000
+ paginator.add_line(content.replace('```', '`` `'))
+ paginator.close_page()
+
+ if message.content:
+ add_content('Raw message', message.content)
+
+ transformer = pprint.pformat if json else self.format_fields
+ for field_name in ('embeds', 'attachments'):
+ data = raw_data[field_name]
+
+ if not data:
+ continue
+
+ total = len(data)
+ for current, item in enumerate(data, start=1):
+ title = f'Raw {field_name} ({current}/{total})'
+ add_content(title, transformer(item))
+
+ for page in paginator.pages:
+ await ctx.send(page)
+
+ @raw.command()
+ async def json(self, ctx: Context, message: discord.Message) -> None:
+ """Shows information about the raw API response in a copy-pasteable Python format."""
+ await ctx.invoke(self.raw, message=message, json=True)
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py
index 8e47bcc36..c92b619ff 100644
--- a/bot/cogs/logging.py
+++ b/bot/cogs/logging.py
@@ -15,9 +15,11 @@ class Logging(Cog):
def __init__(self, bot: Bot):
self.bot = bot
- @Cog.listener()
- async def on_ready(self) -> None:
+ self.bot.loop.create_task(self.startup_greeting())
+
+ async def startup_greeting(self) -> None:
"""Announce our presence to the configured devlog channel."""
+ await self.bot.wait_until_ready()
log.info("Bot connected!")
embed = Embed(description="Connected!")
diff --git a/bot/cogs/moderation.py b/bot/cogs/moderation.py
deleted file mode 100644
index 5aa873a47..000000000
--- a/bot/cogs/moderation.py
+++ /dev/null
@@ -1,1172 +0,0 @@
-import asyncio
-import logging
-import textwrap
-from datetime import datetime
-from typing import Dict, Union
-
-from discord import (
- Colour, Embed, Forbidden, Guild, HTTPException, Member, NotFound, Object, User
-)
-from discord.ext.commands import (
- BadArgument, BadUnionArgument, Bot, Cog, Context, command, group
-)
-
-from bot import constants
-from bot.cogs.modlog import ModLog
-from bot.constants import Colours, Event, Icons, MODERATION_ROLES
-from bot.converters import Duration, InfractionSearchQuery
-from bot.decorators import with_role
-from bot.pagination import LinePaginator
-from bot.utils.moderation import already_has_active_infraction, post_infraction
-from bot.utils.scheduling import Scheduler, create_task
-from bot.utils.time import INFRACTION_FORMAT, format_infraction, wait_until
-
-log = logging.getLogger(__name__)
-
-INFRACTION_ICONS = {
- "Mute": Icons.user_mute,
- "Kick": Icons.sign_out,
- "Ban": Icons.user_ban
-}
-RULES_URL = "https://pythondiscord.com/pages/rules"
-APPEALABLE_INFRACTIONS = ("Ban", "Mute")
-
-
-def proxy_user(user_id: str) -> Object:
- """Create a proxy user for the provided user_id for situations where a Member or User object cannot be resolved."""
- try:
- user_id = int(user_id)
- except ValueError:
- raise BadArgument
- user = Object(user_id)
- user.mention = user.id
- user.avatar_url_as = lambda static_format: None
- return user
-
-
-def permanent_duration(expires_at: str) -> str:
- """Only allow an expiration to be 'permanent' if it is a string."""
- expires_at = expires_at.lower()
- if expires_at != "permanent":
- raise BadArgument
- else:
- return expires_at
-
-
-UserTypes = Union[Member, User, proxy_user]
-
-
-class Moderation(Scheduler, Cog):
- """Server moderation tools."""
-
- def __init__(self, bot: Bot):
- self.bot = bot
- self._muted_role = Object(constants.Roles.muted)
- super().__init__()
-
- @property
- def mod_log(self) -> ModLog:
- """Get currently loaded ModLog cog instance."""
- return self.bot.get_cog("ModLog")
-
- @Cog.listener()
- async def on_ready(self) -> None:
- """Schedule expiration for previous infractions."""
- # Schedule expiration for previous infractions
- infractions = await self.bot.api_client.get(
- 'bot/infractions', params={'active': 'true'}
- )
- for infraction in infractions:
- if infraction["expires_at"] is not None:
- self.schedule_task(self.bot.loop, infraction["id"], infraction)
-
- @Cog.listener()
- async def on_member_join(self, member: Member) -> None:
- """Reapply active mute infractions for returning members."""
- active_mutes = await self.bot.api_client.get(
- 'bot/infractions',
- params={'user__id': str(member.id), 'type': 'mute', 'active': 'true'}
- )
- if not active_mutes:
- return
-
- # assume a single mute because of restrictions elsewhere
- mute = active_mutes[0]
-
- # transform expiration to delay in seconds
- expiration_datetime = datetime.fromisoformat(mute["expires_at"][:-1])
- delay = expiration_datetime - datetime.utcnow()
- delay_seconds = delay.total_seconds()
-
- # if under a minute or in the past
- if delay_seconds < 60:
- log.debug(f"Marking infraction {mute['id']} as inactive (expired).")
- await self._deactivate_infraction(mute)
- self.cancel_task(mute["id"])
-
- # Notify the user that they've been unmuted.
- await self.notify_pardon(
- user=member,
- title="You have been unmuted.",
- content="You may now send messages in the server.",
- icon_url=Icons.user_unmute
- )
- return
-
- # allowing modlog since this is a passive action that should be logged
- await member.add_roles(self._muted_role, reason=f"Re-applying active mute: {mute['id']}")
- log.debug(f"User {member.id} has been re-muted on rejoin.")
-
- # region: Permanent infractions
-
- @with_role(*MODERATION_ROLES)
- @command()
- async def warn(self, ctx: Context, user: UserTypes, *, reason: str = None) -> None:
- """Create a warning infraction in the database for a user."""
- infraction = await post_infraction(ctx, user, type="warning", reason=reason)
- if infraction is None:
- return
-
- notified = await self.notify_infraction(user=user, infr_type="Warning", reason=reason)
-
- dm_result = ":incoming_envelope: " if notified else ""
- action = f"{dm_result}:ok_hand: warned {user.mention}"
- await ctx.send(f"{action}.")
-
- if notified:
- dm_status = "Sent"
- log_content = None
- else:
- dm_status = "**Failed**"
- log_content = ctx.author.mention
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_warn,
- colour=Colour(Colours.soft_red),
- title="Member warned",
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.author}
- DM: {dm_status}
- Reason: {reason}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- @with_role(*MODERATION_ROLES)
- @command()
- async def kick(self, ctx: Context, user: Member, *, reason: str = None) -> None:
- """Kicks a user with the provided reason."""
- if not await self.respect_role_hierarchy(ctx, user, 'kick'):
- # Ensure ctx author has a higher top role than the target user
- # Warning is sent to ctx by the helper method
- return
-
- infraction = await post_infraction(ctx, user, type="kick", reason=reason)
- if infraction is None:
- return
-
- notified = await self.notify_infraction(user=user, infr_type="Kick", reason=reason)
-
- self.mod_log.ignore(Event.member_remove, user.id)
-
- try:
- await user.kick(reason=reason)
- action_result = True
- except Forbidden:
- action_result = False
-
- dm_result = ":incoming_envelope: " if notified else ""
- action = f"{dm_result}:ok_hand: kicked {user.mention}"
- await ctx.send(f"{action}.")
-
- dm_status = "Sent" if notified else "**Failed**"
- title = "Member kicked" if action_result else "Member kicked (Failed)"
- log_content = None if all((notified, action_result)) else ctx.author.mention
-
- await self.mod_log.send_log_message(
- icon_url=Icons.sign_out,
- colour=Colour(Colours.soft_red),
- title=title,
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- DM: {dm_status}
- Reason: {reason}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- @with_role(*MODERATION_ROLES)
- @command()
- async def ban(self, ctx: Context, user: UserTypes, *, reason: str = None) -> None:
- """Create a permanent ban infraction for a user with the provided reason."""
- if not await self.respect_role_hierarchy(ctx, user, 'ban'):
- # Ensure ctx author has a higher top role than the target user
- # Warning is sent to ctx by the helper method
- return
-
- if await already_has_active_infraction(ctx=ctx, user=user, type="ban"):
- return
-
- infraction = await post_infraction(ctx, user, type="ban", reason=reason)
- if infraction is None:
- return
-
- notified = await self.notify_infraction(
- user=user,
- infr_type="Ban",
- reason=reason
- )
-
- self.mod_log.ignore(Event.member_ban, user.id)
- self.mod_log.ignore(Event.member_remove, user.id)
-
- try:
- await ctx.guild.ban(user, reason=reason, delete_message_days=0)
- action_result = True
- except Forbidden:
- action_result = False
-
- dm_result = ":incoming_envelope: " if notified else ""
- action = f"{dm_result}:ok_hand: permanently banned {user.mention}"
- await ctx.send(f"{action}.")
-
- dm_status = "Sent" if notified else "**Failed**"
- log_content = None if all((notified, action_result)) else ctx.author.mention
- title = "Member permanently banned"
- if not action_result:
- title += " (Failed)"
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_ban,
- colour=Colour(Colours.soft_red),
- title=title,
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- DM: {dm_status}
- Reason: {reason}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- # endregion
- # region: Temporary infractions
-
- @with_role(*MODERATION_ROLES)
- @command(aliases=('mute',))
- async def tempmute(self, ctx: Context, user: Member, duration: Duration, *, reason: str = None) -> None:
- """
- Create a temporary mute infraction for a user with the provided expiration and reason.
-
- Duration strings are parsed per: http://strftime.org/
- """
- expiration = duration
-
- if await already_has_active_infraction(ctx=ctx, user=user, type="mute"):
- return
-
- infraction = await post_infraction(ctx, user, type="mute", reason=reason, expires_at=expiration)
- if infraction is None:
- return
-
- self.mod_log.ignore(Event.member_update, user.id)
- await user.add_roles(self._muted_role, reason=reason)
-
- notified = await self.notify_infraction(
- user=user,
- infr_type="Mute",
- expires_at=expiration,
- reason=reason
- )
-
- infraction_expiration = format_infraction(infraction["expires_at"])
-
- self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
-
- dm_result = ":incoming_envelope: " if notified else ""
- action = f"{dm_result}:ok_hand: muted {user.mention} until {infraction_expiration}"
- await ctx.send(f"{action}.")
-
- if notified:
- dm_status = "Sent"
- log_content = None
- else:
- dm_status = "**Failed**"
- log_content = ctx.author.mention
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_mute,
- colour=Colour(Colours.soft_red),
- title="Member temporarily muted",
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- DM: {dm_status}
- Reason: {reason}
- Expires: {infraction_expiration}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- @with_role(*MODERATION_ROLES)
- @command()
- async def tempban(self, ctx: Context, user: UserTypes, duration: Duration, *, reason: str = None) -> None:
- """
- Create a temporary ban infraction for a user with the provided expiration and reason.
-
- Duration strings are parsed per: http://strftime.org/
- """
- expiration = duration
-
- if not await self.respect_role_hierarchy(ctx, user, 'tempban'):
- # Ensure ctx author has a higher top role than the target user
- # Warning is sent to ctx by the helper method
- return
-
- if await already_has_active_infraction(ctx=ctx, user=user, type="ban"):
- return
-
- infraction = await post_infraction(ctx, user, type="ban", reason=reason, expires_at=expiration)
- if infraction is None:
- return
-
- notified = await self.notify_infraction(
- user=user,
- infr_type="Ban",
- expires_at=expiration,
- reason=reason
- )
-
- self.mod_log.ignore(Event.member_ban, user.id)
- self.mod_log.ignore(Event.member_remove, user.id)
-
- try:
- await ctx.guild.ban(user, reason=reason, delete_message_days=0)
- action_result = True
- except Forbidden:
- action_result = False
-
- infraction_expiration = format_infraction(infraction["expires_at"])
-
- self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
-
- dm_result = ":incoming_envelope: " if notified else ""
- action = f"{dm_result}:ok_hand: banned {user.mention} until {infraction_expiration}"
- await ctx.send(f"{action}.")
-
- dm_status = "Sent" if notified else "**Failed**"
- log_content = None if all((notified, action_result)) else ctx.author.mention
- title = "Member temporarily banned"
- if not action_result:
- title += " (Failed)"
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_ban,
- colour=Colour(Colours.soft_red),
- thumbnail=user.avatar_url_as(static_format="png"),
- title=title,
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- DM: {dm_status}
- Reason: {reason}
- Expires: {infraction_expiration}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- # endregion
- # region: Permanent shadow infractions
-
- @with_role(*MODERATION_ROLES)
- @command(hidden=True)
- async def note(self, ctx: Context, user: UserTypes, *, reason: str = None) -> None:
- """
- Create a private infraction note in the database for a user with the provided reason.
-
- This does not send the user a notification
- """
- infraction = await post_infraction(ctx, user, type="note", reason=reason, hidden=True)
- if infraction is None:
- return
-
- await ctx.send(f":ok_hand: note added for {user.mention}.")
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_warn,
- colour=Colour(Colours.soft_red),
- title="Member note added",
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- Reason: {reason}
- """),
- footer=f"ID {infraction['id']}"
- )
-
- @with_role(*MODERATION_ROLES)
- @command(hidden=True, aliases=['shadowkick', 'skick'])
- async def shadow_kick(self, ctx: Context, user: Member, *, reason: str = None) -> None:
- """
- Kick a user for the provided reason.
-
- This does not send the user a notification.
- """
- if not await self.respect_role_hierarchy(ctx, user, 'shadowkick'):
- # Ensure ctx author has a higher top role than the target user
- # Warning is sent to ctx by the helper method
- return
-
- infraction = await post_infraction(ctx, user, type="kick", reason=reason, hidden=True)
- if infraction is None:
- return
-
- self.mod_log.ignore(Event.member_remove, user.id)
-
- try:
- await user.kick(reason=reason)
- action_result = True
- except Forbidden:
- action_result = False
-
- await ctx.send(f":ok_hand: kicked {user.mention}.")
-
- title = "Member shadow kicked"
- if action_result:
- log_content = None
- else:
- log_content = ctx.author.mention
- title += " (Failed)"
-
- await self.mod_log.send_log_message(
- icon_url=Icons.sign_out,
- colour=Colour(Colours.soft_red),
- title=title,
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- Reason: {reason}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- @with_role(*MODERATION_ROLES)
- @command(hidden=True, aliases=['shadowban', 'sban'])
- async def shadow_ban(self, ctx: Context, user: UserTypes, *, reason: str = None) -> None:
- """
- Create a permanent ban infraction for a user with the provided reason.
-
- This does not send the user a notification.
- """
- if not await self.respect_role_hierarchy(ctx, user, 'shadowban'):
- # Ensure ctx author has a higher top role than the target user
- # Warning is sent to ctx by the helper method
- return
-
- if await already_has_active_infraction(ctx=ctx, user=user, type="ban"):
- return
-
- infraction = await post_infraction(ctx, user, type="ban", reason=reason, hidden=True)
- if infraction is None:
- return
-
- self.mod_log.ignore(Event.member_ban, user.id)
- self.mod_log.ignore(Event.member_remove, user.id)
-
- try:
- await ctx.guild.ban(user, reason=reason, delete_message_days=0)
- action_result = True
- except Forbidden:
- action_result = False
-
- await ctx.send(f":ok_hand: permanently banned {user.mention}.")
-
- title = "Member permanently banned"
- if action_result:
- log_content = None
- else:
- log_content = ctx.author.mention
- title += " (Failed)"
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_ban,
- colour=Colour(Colours.soft_red),
- title=title,
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- Reason: {reason}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- # endregion
- # region: Temporary shadow infractions
-
- @with_role(*MODERATION_ROLES)
- @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"])
- async def shadow_tempmute(
- self, ctx: Context, user: Member, duration: Duration, *, reason: str = None
- ) -> None:
- """
- Create a temporary mute infraction for a user with the provided reason.
-
- Duration strings are parsed per: http://strftime.org/
-
- This does not send the user a notification.
- """
- expiration = duration
-
- if await already_has_active_infraction(ctx=ctx, user=user, type="mute"):
- return
-
- infraction = await post_infraction(ctx, user, type="mute", reason=reason, expires_at=expiration, hidden=True)
- if infraction is None:
- return
-
- self.mod_log.ignore(Event.member_update, user.id)
- await user.add_roles(self._muted_role, reason=reason)
-
- infraction_expiration = format_infraction(infraction["expires_at"])
- self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
- await ctx.send(f":ok_hand: muted {user.mention} until {infraction_expiration}.")
-
- await self.mod_log.send_log_message(
- icon_url=Icons.user_mute,
- colour=Colour(Colours.soft_red),
- title="Member temporarily muted",
- thumbnail=user.avatar_url_as(static_format="png"),
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- Reason: {reason}
- Expires: {infraction_expiration}
- """),
- footer=f"ID {infraction['id']}"
- )
-
- @with_role(*MODERATION_ROLES)
- @command(hidden=True, aliases=["shadowtempban, stempban"])
- async def shadow_tempban(
- self, ctx: Context, user: UserTypes, duration: Duration, *, reason: str = None
- ) -> None:
- """
- Create a temporary ban infraction for a user with the provided reason.
-
- Duration strings are parsed per: http://strftime.org/
-
- This does not send the user a notification.
- """
- expiration = duration
-
- if not await self.respect_role_hierarchy(ctx, user, 'shadowtempban'):
- # Ensure ctx author has a higher top role than the target user
- # Warning is sent to ctx by the helper method
- return
-
- if await already_has_active_infraction(ctx=ctx, user=user, type="ban"):
- return
-
- infraction = await post_infraction(ctx, user, type="ban", reason=reason, expires_at=expiration, hidden=True)
- if infraction is None:
- return
-
- self.mod_log.ignore(Event.member_ban, user.id)
- self.mod_log.ignore(Event.member_remove, user.id)
-
- try:
- await ctx.guild.ban(user, reason=reason, delete_message_days=0)
- action_result = True
- except Forbidden:
- action_result = False
-
- infraction_expiration = format_infraction(infraction["expires_at"])
- self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
- await ctx.send(f":ok_hand: banned {user.mention} until {infraction_expiration}.")
-
- title = "Member temporarily banned"
- if action_result:
- log_content = None
- else:
- log_content = ctx.author.mention
- title += " (Failed)"
-
- # Send a log message to the mod log
- await self.mod_log.send_log_message(
- icon_url=Icons.user_ban,
- colour=Colour(Colours.soft_red),
- thumbnail=user.avatar_url_as(static_format="png"),
- title=title,
- text=textwrap.dedent(f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- Reason: {reason}
- Expires: {infraction_expiration}
- """),
- content=log_content,
- footer=f"ID {infraction['id']}"
- )
-
- # endregion
- # region: Remove infractions (un- commands)
-
- @with_role(*MODERATION_ROLES)
- @command()
- async def unmute(self, ctx: Context, user: UserTypes) -> None:
- """Deactivates the active mute infraction for a user."""
- try:
- # check the current active infraction
- response = await self.bot.api_client.get(
- 'bot/infractions',
- params={
- 'active': 'true',
- 'type': 'mute',
- 'user__id': user.id
- }
- )
- if len(response) > 1:
- log.warning("Found more than one active mute infraction for user `%d`", user.id)
-
- if not response:
- # no active infraction
- await ctx.send(
- f":x: There is no active mute infraction for user {user.mention}."
- )
- return
-
- for infraction in response:
- await self._deactivate_infraction(infraction)
- if infraction["expires_at"] is not None:
- self.cancel_expiration(infraction["id"])
-
- notified = await self.notify_pardon(
- user=user,
- title="You have been unmuted.",
- content="You may now send messages in the server.",
- icon_url=Icons.user_unmute
- )
-
- if notified:
- dm_status = "Sent"
- dm_emoji = ":incoming_envelope: "
- log_content = None
- else:
- dm_status = "**Failed**"
- dm_emoji = ""
- log_content = ctx.author.mention
-
- await ctx.send(f"{dm_emoji}:ok_hand: Un-muted {user.mention}.")
-
- embed_text = textwrap.dedent(
- f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- DM: {dm_status}
- """
- )
-
- if len(response) > 1:
- footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}"
- title = "Member unmuted"
- embed_text += "Note: User had multiple **active** mute infractions in the database."
- else:
- infraction = response[0]
- footer = f"Infraction ID: {infraction['id']}"
- title = "Member unmuted"
-
- # Send a log message to the mod log
- await self.mod_log.send_log_message(
- icon_url=Icons.user_unmute,
- colour=Colour(Colours.soft_green),
- title=title,
- thumbnail=user.avatar_url_as(static_format="png"),
- text=embed_text,
- footer=footer,
- content=log_content
- )
- except Exception:
- log.exception("There was an error removing an infraction.")
- await ctx.send(":x: There was an error removing the infraction.")
-
- @with_role(*MODERATION_ROLES)
- @command()
- async def unban(self, ctx: Context, user: UserTypes) -> None:
- """Deactivates the active ban infraction for a user."""
- try:
- # check the current active infraction
- response = await self.bot.api_client.get(
- 'bot/infractions',
- params={
- 'active': 'true',
- 'type': 'ban',
- 'user__id': str(user.id)
- }
- )
- if len(response) > 1:
- log.warning(
- "More than one active ban infraction found for user `%d`.",
- user.id
- )
-
- if not response:
- # no active infraction
- await ctx.send(
- f":x: There is no active ban infraction for user {user.mention}."
- )
- return
-
- for infraction in response:
- await self._deactivate_infraction(infraction)
- if infraction["expires_at"] is not None:
- self.cancel_expiration(infraction["id"])
-
- embed_text = textwrap.dedent(
- f"""
- Member: {user.mention} (`{user.id}`)
- Actor: {ctx.message.author}
- """
- )
-
- if len(response) > 1:
- footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}"
- embed_text += "Note: User had multiple **active** ban infractions in the database."
- else:
- infraction = response[0]
- footer = f"Infraction ID: {infraction['id']}"
-
- await ctx.send(f":ok_hand: Un-banned {user.mention}.")
-
- # Send a log message to the mod log
- await self.mod_log.send_log_message(
- icon_url=Icons.user_unban,
- colour=Colour(Colours.soft_green),
- title="Member unbanned",
- thumbnail=user.avatar_url_as(static_format="png"),
- text=embed_text,
- footer=footer,
- )
- except Exception:
- log.exception("There was an error removing an infraction.")
- await ctx.send(":x: There was an error removing the infraction.")
-
- # endregion
- # region: Edit infraction commands
-
- @with_role(*MODERATION_ROLES)
- @group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True)
- async def infraction_group(self, ctx: Context) -> None:
- """Infraction manipulation commands."""
- await ctx.invoke(self.bot.get_command("help"), "infraction")
-
- @with_role(*MODERATION_ROLES)
- @infraction_group.command(name='edit')
- async def infraction_edit(
- self,
- ctx: Context,
- infraction_id: int,
- expires_at: Union[Duration, permanent_duration, None],
- *,
- reason: str = None
- ) -> None:
- """
- Edit the duration and/or the reason of an infraction.
-
- Durations are relative to the time of updating.
- Use "permanent" to mark the infraction as permanent.
- """
- if expires_at is None and reason is None:
- # Unlike UserInputError, the error handler will show a specified message for BadArgument
- raise BadArgument("Neither a new expiry nor a new reason was specified.")
-
- # Retrieve the previous infraction for its information.
- old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}')
-
- request_data = {}
- confirm_messages = []
- log_text = ""
-
- if expires_at == "permanent":
- request_data['expires_at'] = None
- confirm_messages.append("marked as permanent")
- elif expires_at is not None:
- request_data['expires_at'] = expires_at.isoformat()
- confirm_messages.append(f"set to expire on {expires_at.strftime(INFRACTION_FORMAT)}")
- else:
- confirm_messages.append("expiry unchanged")
-
- if reason:
- request_data['reason'] = reason
- confirm_messages.append("set a new reason")
- log_text += f"""
- Previous reason: {old_infraction['reason']}
- New reason: {reason}
- """.rstrip()
- else:
- confirm_messages.append("reason unchanged")
-
- # Update the infraction
- new_infraction = await self.bot.api_client.patch(
- f'bot/infractions/{infraction_id}',
- json=request_data,
- )
-
- # Re-schedule infraction if the expiration has been updated
- if 'expires_at' in request_data:
- self.cancel_task(new_infraction['id'])
- loop = asyncio.get_event_loop()
- self.schedule_task(loop, new_infraction['id'], new_infraction)
-
- log_text += f"""
- Previous expiry: {old_infraction['expires_at'] or "Permanent"}
- New expiry: {new_infraction['expires_at'] or "Permanent"}
- """.rstrip()
-
- await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}")
-
- # Get information about the infraction's user
- user_id = new_infraction['user']
- user = ctx.guild.get_member(user_id)
-
- if user:
- user_text = f"{user.mention} (`{user.id}`)"
- thumbnail = user.avatar_url_as(static_format="png")
- else:
- user_text = f"`{user_id}`"
- thumbnail = None
-
- # The infraction's actor
- actor_id = new_infraction['actor']
- actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`"
-
- await self.mod_log.send_log_message(
- icon_url=Icons.pencil,
- colour=Colour.blurple(),
- title="Infraction edited",
- thumbnail=thumbnail,
- text=textwrap.dedent(f"""
- Member: {user_text}
- Actor: {actor}
- Edited by: {ctx.message.author}{log_text}
- """)
- )
-
- # endregion
- # region: Search infractions
-
- @with_role(*MODERATION_ROLES)
- @infraction_group.group(name="search", invoke_without_command=True)
- async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None:
- """Searches for infractions in the database."""
- if isinstance(query, User):
- await ctx.invoke(self.search_user, query)
-
- else:
- await ctx.invoke(self.search_reason, query)
-
- @with_role(*MODERATION_ROLES)
- @infraction_search_group.command(name="user", aliases=("member", "id"))
- async def search_user(self, ctx: Context, user: Union[User, proxy_user]) -> None:
- """Search for infractions by member."""
- infraction_list = await self.bot.api_client.get(
- 'bot/infractions',
- params={'user__id': str(user.id)}
- )
- embed = Embed(
- title=f"Infractions for {user} ({len(infraction_list)} total)",
- colour=Colour.orange()
- )
- await self.send_infraction_list(ctx, embed, infraction_list)
-
- @with_role(*MODERATION_ROLES)
- @infraction_search_group.command(name="reason", aliases=("match", "regex", "re"))
- async def search_reason(self, ctx: Context, reason: str) -> None:
- """Search for infractions by their reason. Use Re2 for matching."""
- infraction_list = await self.bot.api_client.get(
- 'bot/infractions', params={'search': reason}
- )
- embed = Embed(
- title=f"Infractions matching `{reason}` ({len(infraction_list)} total)",
- colour=Colour.orange()
- )
- await self.send_infraction_list(ctx, embed, infraction_list)
-
- # endregion
- # region: Utility functions
-
- async def send_infraction_list(self, ctx: Context, embed: Embed, infractions: list) -> None:
- """Send a paginated embed of infractions for the specified user."""
- if not infractions:
- await ctx.send(f":warning: No infractions could be found for that query.")
- return
-
- lines = tuple(
- self._infraction_to_string(infraction)
- for infraction in infractions
- )
-
- await LinePaginator.paginate(
- lines,
- ctx=ctx,
- embed=embed,
- empty=True,
- max_lines=3,
- max_size=1000
- )
-
- # endregion
- # region: Utility functions
-
- def schedule_expiration(
- self, loop: asyncio.AbstractEventLoop, infraction_object: Dict[str, Union[str, int, bool]]
- ) -> None:
- """Schedules a task to expire a temporary infraction."""
- infraction_id = infraction_object["id"]
- if infraction_id in self.scheduled_tasks:
- return
-
- task: asyncio.Task = create_task(loop, self._scheduled_expiration(infraction_object))
-
- self.scheduled_tasks[infraction_id] = task
-
- def cancel_expiration(self, infraction_id: str) -> None:
- """Un-schedules a task set to expire a temporary infraction."""
- task = self.scheduled_tasks.get(infraction_id)
- if task is None:
- log.warning(f"Failed to unschedule {infraction_id}: no task found.")
- return
- task.cancel()
- log.debug(f"Unscheduled {infraction_id}.")
- del self.scheduled_tasks[infraction_id]
-
- async def _scheduled_task(self, infraction_object: Dict[str, Union[str, int, bool]]) -> None:
- """
- Marks an infraction expired after the delay from time of scheduling to time of expiration.
-
- At the time of expiration, the infraction is marked as inactive on the website, and the
- expiration task is cancelled. The user is then notified via DM.
- """
- infraction_id = infraction_object["id"]
-
- # transform expiration to delay in seconds
- expiration_datetime = datetime.fromisoformat(infraction_object["expires_at"][:-1])
- await wait_until(expiration_datetime)
-
- log.debug(f"Marking infraction {infraction_id} as inactive (expired).")
- await self._deactivate_infraction(infraction_object)
-
- self.cancel_task(infraction_object["id"])
-
- # Notify the user that they've been unmuted.
- user_id = infraction_object["user"]
- guild = self.bot.get_guild(constants.Guild.id)
- await self.notify_pardon(
- user=guild.get_member(user_id),
- title="You have been unmuted.",
- content="You may now send messages in the server.",
- icon_url=Icons.user_unmute
- )
-
- async def _deactivate_infraction(self, infraction_object: Dict[str, Union[str, int, bool]]) -> None:
- """
- A co-routine which marks an infraction as inactive on the website.
-
- This co-routine does not cancel or un-schedule an expiration task.
- """
- guild: Guild = self.bot.get_guild(constants.Guild.id)
- user_id = infraction_object["user"]
- infraction_type = infraction_object["type"]
-
- await self.bot.api_client.patch(
- 'bot/infractions/' + str(infraction_object['id']),
- json={"active": False}
- )
-
- if infraction_type == "mute":
- member: Member = guild.get_member(user_id)
- if member:
- # remove the mute role
- self.mod_log.ignore(Event.member_update, member.id)
- await member.remove_roles(self._muted_role)
- else:
- log.warning(f"Failed to un-mute user: {user_id} (not found)")
- elif infraction_type == "ban":
- user: Object = Object(user_id)
- try:
- await guild.unban(user)
- except NotFound:
- log.info(f"Tried to unban user `{user_id}`, but Discord does not have an active ban registered.")
-
- def _infraction_to_string(self, infraction_object: Dict[str, Union[str, int, bool]]) -> str:
- """Convert the infraction object to a string representation."""
- actor_id = infraction_object["actor"]
- guild: Guild = self.bot.get_guild(constants.Guild.id)
- actor = guild.get_member(actor_id)
- active = infraction_object["active"]
- user_id = infraction_object["user"]
- hidden = infraction_object["hidden"]
- created = format_infraction(infraction_object["inserted_at"])
- if infraction_object["expires_at"] is None:
- expires = "*Permanent*"
- else:
- expires = format_infraction(infraction_object["expires_at"])
-
- lines = textwrap.dedent(f"""
- {"**===============**" if active else "==============="}
- Status: {"__**Active**__" if active else "Inactive"}
- User: {self.bot.get_user(user_id)} (`{user_id}`)
- Type: **{infraction_object["type"]}**
- Shadow: {hidden}
- Reason: {infraction_object["reason"] or "*None*"}
- Created: {created}
- Expires: {expires}
- Actor: {actor.mention if actor else actor_id}
- ID: `{infraction_object["id"]}`
- {"**===============**" if active else "==============="}
- """)
-
- return lines.strip()
-
- async def notify_infraction(
- self,
- user: Union[User, Member],
- infr_type: str,
- expires_at: Union[datetime, str] = 'N/A',
- reason: str = "No reason provided."
- ) -> bool:
- """
- Attempt to notify a user, via DM, of their fresh infraction.
-
- Returns a boolean indicator of whether the DM was successful.
- """
- if isinstance(expires_at, datetime):
- expires_at = expires_at.strftime(INFRACTION_FORMAT)
-
- embed = Embed(
- description=textwrap.dedent(f"""
- **Type:** {infr_type}
- **Expires:** {expires_at}
- **Reason:** {reason}
- """),
- colour=Colour(Colours.soft_red)
- )
-
- icon_url = INFRACTION_ICONS.get(infr_type, Icons.token_removed)
- embed.set_author(name="Infraction Information", icon_url=icon_url, url=RULES_URL)
- embed.title = f"Please review our rules over at {RULES_URL}"
- embed.url = RULES_URL
-
- if infr_type in APPEALABLE_INFRACTIONS:
- embed.set_footer(text="To appeal this infraction, send an e-mail to [email protected]")
-
- return await self.send_private_embed(user, embed)
-
- async def notify_pardon(
- self,
- user: Union[User, Member],
- title: str,
- content: str,
- icon_url: str = Icons.user_verified
- ) -> bool:
- """
- Attempt to notify a user, via DM, of their expired infraction.
-
- Optionally returns a boolean indicator of whether the DM was successful.
- """
- embed = Embed(
- description=content,
- colour=Colour(Colours.soft_green)
- )
-
- embed.set_author(name=title, icon_url=icon_url)
-
- return await self.send_private_embed(user, embed)
-
- async def send_private_embed(self, user: Union[User, Member], embed: Embed) -> bool:
- """
- A helper method for sending an embed to a user's DMs.
-
- Returns a boolean indicator of DM success.
- """
- # sometimes `user` is a `discord.Object`, so let's make it a proper user.
- user = await self.bot.fetch_user(user.id)
-
- try:
- await user.send(embed=embed)
- return True
- except (HTTPException, Forbidden):
- log.debug(
- f"Infraction-related information could not be sent to user {user} ({user.id}). "
- "They've probably just disabled private messages."
- )
- return False
-
- async def log_notify_failure(self, target: str, actor: Member, infraction_type: str) -> None:
- """Send a mod log entry if an attempt to DM the target user has failed."""
- await self.mod_log.send_log_message(
- icon_url=Icons.token_removed,
- content=actor.mention,
- colour=Colour(Colours.soft_red),
- title="Notification Failed",
- text=(
- f"Direct message was unable to be sent.\nUser: {target.mention}\n"
- f"Type: {infraction_type}"
- )
- )
-
- # endregion
-
- # This cannot be static (must have a __func__ attribute).
- async def cog_command_error(self, ctx: Context, error: Exception) -> None:
- """Send a notification to the invoking context on a Union failure."""
- if isinstance(error, BadUnionArgument):
- if User in error.converters:
- await ctx.send(str(error.errors[0]))
- error.handled = True
-
- @staticmethod
- async def respect_role_hierarchy(ctx: Context, target: UserTypes, infr_type: str) -> bool:
- """
- Check if the highest role of the invoking member is greater than that of the target member.
-
- If this check fails, a warning is sent to the invoking ctx.
-
- Returns True always if target is not a discord.Member instance.
- """
- if not isinstance(target, Member):
- return True
-
- actor = ctx.author
- target_is_lower = target.top_role < actor.top_role
- if not target_is_lower:
- log.info(
- f"{actor} ({actor.id}) attempted to {infr_type} "
- f"{target} ({target.id}), who has an equal or higher top role."
- )
- await ctx.send(
- f":x: {actor.mention}, you may not {infr_type} "
- "someone with an equal or higher top role."
- )
-
- return target_is_lower
-
-
-def setup(bot: Bot) -> None:
- """Moderation cog load."""
- bot.add_cog(Moderation(bot))
- log.info("Cog loaded: Moderation")
diff --git a/bot/cogs/moderation/__init__.py b/bot/cogs/moderation/__init__.py
new file mode 100644
index 000000000..7383ed44e
--- /dev/null
+++ b/bot/cogs/moderation/__init__.py
@@ -0,0 +1,25 @@
+import logging
+
+from discord.ext.commands import Bot
+
+from .infractions import Infractions
+from .management import ModManagement
+from .modlog import ModLog
+from .superstarify import Superstarify
+
+log = logging.getLogger(__name__)
+
+
+def setup(bot: Bot) -> None:
+ """Load the moderation extension (Infractions, ModManagement, ModLog, & Superstarify cogs)."""
+ bot.add_cog(Infractions(bot))
+ log.info("Cog loaded: Infractions")
+
+ bot.add_cog(ModLog(bot))
+ log.info("Cog loaded: ModLog")
+
+ bot.add_cog(ModManagement(bot))
+ log.info("Cog loaded: ModManagement")
+
+ bot.add_cog(Superstarify(bot))
+ log.info("Cog loaded: Superstarify")
diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py
new file mode 100644
index 000000000..997ffe524
--- /dev/null
+++ b/bot/cogs/moderation/infractions.py
@@ -0,0 +1,617 @@
+import logging
+import textwrap
+import typing as t
+from datetime import datetime
+from gettext import ngettext
+
+import dateutil.parser
+import discord
+from discord import Member
+from discord.ext import commands
+from discord.ext.commands import Context, command
+
+from bot import constants
+from bot.api import ResponseCodeError
+from bot.constants import Colours, Event, STAFF_CHANNELS
+from bot.decorators import respect_role_hierarchy
+from bot.utils import time
+from bot.utils.checks import with_role_check
+from bot.utils.scheduling import Scheduler
+from . import utils
+from .modlog import ModLog
+from .utils import MemberObject
+
+log = logging.getLogger(__name__)
+
+MemberConverter = t.Union[utils.UserTypes, utils.proxy_user]
+
+
+class Infractions(Scheduler, commands.Cog):
+ """Apply and pardon infractions on users for moderation purposes."""
+
+ category = "Moderation"
+ category_description = "Server moderation tools."
+
+ def __init__(self, bot: commands.Bot):
+ super().__init__()
+
+ self.bot = bot
+ self.category = "Moderation"
+ self._muted_role = discord.Object(constants.Roles.muted)
+
+ self.bot.loop.create_task(self.reschedule_infractions())
+
+ @property
+ def mod_log(self) -> ModLog:
+ """Get currently loaded ModLog cog instance."""
+ return self.bot.get_cog("ModLog")
+
+ async def reschedule_infractions(self) -> None:
+ """Schedule expiration for previous infractions."""
+ await self.bot.wait_until_ready()
+
+ infractions = await self.bot.api_client.get(
+ 'bot/infractions',
+ params={'active': 'true'}
+ )
+ for infraction in infractions:
+ if infraction["expires_at"] is not None:
+ self.schedule_task(self.bot.loop, infraction["id"], infraction)
+
+ @commands.Cog.listener()
+ async def on_member_join(self, member: Member) -> None:
+ """Reapply active mute infractions for returning members."""
+ active_mutes = await self.bot.api_client.get(
+ 'bot/infractions',
+ params={
+ 'user__id': str(member.id),
+ 'type': 'mute',
+ 'active': 'true'
+ }
+ )
+ if not active_mutes:
+ return
+
+ # Assume a single mute because of restrictions elsewhere.
+ mute = active_mutes[0]
+
+ # Calculate the time remaining, in seconds, for the mute.
+ expiry = dateutil.parser.isoparse(mute["expires_at"]).replace(tzinfo=None)
+ delta = (expiry - datetime.utcnow()).total_seconds()
+
+ # Mark as inactive if less than a minute remains.
+ if delta < 60:
+ await self.deactivate_infraction(mute)
+ return
+
+ # Allowing mod log since this is a passive action that should be logged.
+ await member.add_roles(self._muted_role, reason=f"Re-applying active mute: {mute['id']}")
+ log.debug(f"User {member.id} has been re-muted on rejoin.")
+
+ # region: Permanent infractions
+
+ @command()
+ async def warn(self, ctx: Context, user: Member, *, reason: str = None) -> None:
+ """Warn a user for the given reason."""
+ infraction = await utils.post_infraction(ctx, user, "warning", reason, active=False)
+ if infraction is None:
+ return
+
+ await self.apply_infraction(ctx, infraction, user)
+
+ @command()
+ async def kick(self, ctx: Context, user: Member, *, reason: str = None) -> None:
+ """Kick a user for the given reason."""
+ await self.apply_kick(ctx, user, reason, active=False)
+
+ @command()
+ async def ban(self, ctx: Context, user: MemberConverter, *, reason: str = None) -> None:
+ """Permanently ban a user for the given reason."""
+ await self.apply_ban(ctx, user, reason)
+
+ # endregion
+ # region: Temporary infractions
+
+ @command(aliases=["mute"])
+ async def tempmute(self, ctx: Context, user: Member, duration: utils.Expiry, *, reason: str = None) -> None:
+ """
+ Temporarily mute a user for the given reason and duration.
+
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
+ """
+ await self.apply_mute(ctx, user, reason, expires_at=duration)
+
+ @command()
+ async def tempban(self, ctx: Context, user: MemberConverter, duration: utils.Expiry, *, reason: str = None) -> None:
+ """
+ Temporarily ban a user for the given reason and duration.
+
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
+ """
+ await self.apply_ban(ctx, user, reason, expires_at=duration)
+
+ # endregion
+ # region: Permanent shadow infractions
+
+ @command(hidden=True)
+ async def note(self, ctx: Context, user: MemberConverter, *, reason: str = None) -> None:
+ """Create a private note for a user with the given reason without notifying the user."""
+ infraction = await utils.post_infraction(ctx, user, "note", reason, hidden=True, active=False)
+ if infraction is None:
+ return
+
+ await self.apply_infraction(ctx, infraction, user)
+
+ @command(hidden=True, aliases=['shadowkick', 'skick'])
+ async def shadow_kick(self, ctx: Context, user: Member, *, reason: str = None) -> None:
+ """Kick a user for the given reason without notifying the user."""
+ await self.apply_kick(ctx, user, reason, hidden=True, active=False)
+
+ @command(hidden=True, aliases=['shadowban', 'sban'])
+ async def shadow_ban(self, ctx: Context, user: MemberConverter, *, reason: str = None) -> None:
+ """Permanently ban a user for the given reason without notifying the user."""
+ await self.apply_ban(ctx, user, reason, hidden=True)
+
+ # endregion
+ # region: Temporary shadow infractions
+
+ @command(hidden=True, aliases=["shadowtempmute, stempmute", "shadowmute", "smute"])
+ async def shadow_tempmute(self, ctx: Context, user: Member, duration: utils.Expiry, *, reason: str = None) -> None:
+ """
+ Temporarily mute a user for the given reason and duration without notifying the user.
+
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
+ """
+ await self.apply_mute(ctx, user, reason, expires_at=duration, hidden=True)
+
+ @command(hidden=True, aliases=["shadowtempban, stempban"])
+ async def shadow_tempban(
+ self,
+ ctx: Context,
+ user: MemberConverter,
+ duration: utils.Expiry,
+ *,
+ reason: str = None
+ ) -> None:
+ """
+ Temporarily ban a user for the given reason and duration without notifying the user.
+
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
+ """
+ await self.apply_ban(ctx, user, reason, expires_at=duration, hidden=True)
+
+ # endregion
+ # region: Remove infractions (un- commands)
+
+ @command()
+ async def unmute(self, ctx: Context, user: MemberConverter) -> None:
+ """Prematurely end the active mute infraction for the user."""
+ await self.pardon_infraction(ctx, "mute", user)
+
+ @command()
+ async def unban(self, ctx: Context, user: MemberConverter) -> None:
+ """Prematurely end the active ban infraction for the user."""
+ await self.pardon_infraction(ctx, "ban", user)
+
+ # endregion
+ # region: Base infraction functions
+
+ async def apply_mute(self, ctx: Context, user: Member, reason: str, **kwargs) -> None:
+ """Apply a mute infraction with kwargs passed to `post_infraction`."""
+ if await utils.has_active_infraction(ctx, user, "mute"):
+ return
+
+ infraction = await utils.post_infraction(ctx, user, "mute", reason, **kwargs)
+ if infraction is None:
+ return
+
+ self.mod_log.ignore(Event.member_update, user.id)
+
+ action = user.add_roles(self._muted_role, reason=reason)
+ await self.apply_infraction(ctx, infraction, user, action)
+
+ @respect_role_hierarchy()
+ async def apply_kick(self, ctx: Context, user: Member, reason: str, **kwargs) -> None:
+ """Apply a kick infraction with kwargs passed to `post_infraction`."""
+ infraction = await utils.post_infraction(ctx, user, "kick", reason, **kwargs)
+ if infraction is None:
+ return
+
+ self.mod_log.ignore(Event.member_remove, user.id)
+
+ action = user.kick(reason=reason)
+ await self.apply_infraction(ctx, infraction, user, action)
+
+ @respect_role_hierarchy()
+ async def apply_ban(self, ctx: Context, user: MemberObject, reason: str, **kwargs) -> None:
+ """Apply a ban infraction with kwargs passed to `post_infraction`."""
+ if await utils.has_active_infraction(ctx, user, "ban"):
+ return
+
+ infraction = await utils.post_infraction(ctx, user, "ban", reason, **kwargs)
+ if infraction is None:
+ return
+
+ self.mod_log.ignore(Event.member_remove, user.id)
+
+ action = ctx.guild.ban(user, reason=reason, delete_message_days=0)
+ await self.apply_infraction(ctx, infraction, user, action)
+
+ # endregion
+ # region: Utility functions
+
+ async def _scheduled_task(self, infraction: utils.Infraction) -> None:
+ """
+ Marks an infraction expired after the delay from time of scheduling to time of expiration.
+
+ At the time of expiration, the infraction is marked as inactive on the website and the
+ expiration task is cancelled.
+ """
+ _id = infraction["id"]
+
+ expiry = dateutil.parser.isoparse(infraction["expires_at"]).replace(tzinfo=None)
+ await time.wait_until(expiry)
+
+ log.debug(f"Marking infraction {_id} as inactive (expired).")
+ await self.deactivate_infraction(infraction)
+
+ async def deactivate_infraction(
+ self,
+ infraction: utils.Infraction,
+ send_log: bool = True
+ ) -> t.Dict[str, str]:
+ """
+ Deactivate an active infraction and return a dictionary of lines to send in a mod log.
+
+ The infraction is removed from Discord, marked as inactive in the database, and has its
+ expiration task cancelled. If `send_log` is True, a mod log is sent for the
+ deactivation of the infraction.
+
+ Supported infraction types are mute and ban. Other types will raise a ValueError.
+ """
+ guild = self.bot.get_guild(constants.Guild.id)
+ mod_role = guild.get_role(constants.Roles.moderator)
+ user_id = infraction["user"]
+ _type = infraction["type"]
+ _id = infraction["id"]
+ reason = f"Infraction #{_id} expired or was pardoned."
+
+ log.debug(f"Marking infraction #{_id} as inactive (expired).")
+
+ log_content = None
+ log_text = {
+ "Member": str(user_id),
+ "Actor": str(self.bot.user),
+ "Reason": infraction["reason"]
+ }
+
+ try:
+ if _type == "mute":
+ user = guild.get_member(user_id)
+ if user:
+ # Remove the muted role.
+ self.mod_log.ignore(Event.member_update, user.id)
+ await user.remove_roles(self._muted_role, reason=reason)
+
+ # DM the user about the expiration.
+ notified = await utils.notify_pardon(
+ user=user,
+ title="You have been unmuted.",
+ content="You may now send messages in the server.",
+ icon_url=utils.INFRACTION_ICONS["mute"][1]
+ )
+
+ log_text["Member"] = f"{user.mention}(`{user.id}`)"
+ log_text["DM"] = "Sent" if notified else "**Failed**"
+ else:
+ log.info(f"Failed to unmute user {user_id}: user not found")
+ log_text["Failure"] = "User was not found in the guild."
+ elif _type == "ban":
+ user = discord.Object(user_id)
+ self.mod_log.ignore(Event.member_unban, user_id)
+ try:
+ await guild.unban(user, reason=reason)
+ except discord.NotFound:
+ log.info(f"Failed to unban user {user_id}: no active ban found on Discord")
+ log_text["Note"] = "No active ban found on Discord."
+ else:
+ raise ValueError(
+ f"Attempted to deactivate an unsupported infraction #{_id} ({_type})!"
+ )
+ except discord.Forbidden:
+ log.warning(f"Failed to deactivate infraction #{_id} ({_type}): bot lacks permissions")
+ log_text["Failure"] = f"The bot lacks permissions to do this (role hierarchy?)"
+ log_content = mod_role.mention
+ except discord.HTTPException as e:
+ log.exception(f"Failed to deactivate infraction #{_id} ({_type})")
+ log_text["Failure"] = f"HTTPException with code {e.code}."
+ log_content = mod_role.mention
+
+ # Check if the user is currently being watched by Big Brother.
+ try:
+ active_watch = await self.bot.api_client.get(
+ "bot/infractions",
+ params={
+ "active": "true",
+ "type": "watch",
+ "user__id": user_id
+ }
+ )
+
+ log_text["Watching"] = "Yes" if active_watch else "No"
+ except ResponseCodeError:
+ log.exception(f"Failed to fetch watch status for user {user_id}")
+ log_text["Watching"] = "Unknown - failed to fetch watch status."
+
+ try:
+ # Mark infraction as inactive in the database.
+ await self.bot.api_client.patch(
+ f"bot/infractions/{_id}",
+ json={"active": False}
+ )
+ except ResponseCodeError as e:
+ log.exception(f"Failed to deactivate infraction #{_id} ({_type})")
+ log_line = f"API request failed with code {e.status}."
+ log_content = mod_role.mention
+
+ # Append to an existing failure message if possible
+ if "Failure" in log_text:
+ log_text["Failure"] += f" {log_line}"
+ else:
+ log_text["Failure"] = log_line
+
+ # Cancel the expiration task.
+ if infraction["expires_at"] is not None:
+ self.cancel_task(infraction["id"])
+
+ # Send a log message to the mod log.
+ if send_log:
+ log_title = f"expiration failed" if "Failure" in log_text else "expired"
+
+ await self.mod_log.send_log_message(
+ icon_url=utils.INFRACTION_ICONS[_type][1],
+ colour=Colours.soft_green,
+ title=f"Infraction {log_title}: {_type}",
+ text="\n".join(f"{k}: {v}" for k, v in log_text.items()),
+ footer=f"ID: {_id}",
+ content=log_content,
+ )
+
+ return log_text
+
+ async def apply_infraction(
+ self,
+ ctx: Context,
+ infraction: utils.Infraction,
+ user: MemberObject,
+ action_coro: t.Optional[t.Awaitable] = None
+ ) -> None:
+ """Apply an infraction to the user, log the infraction, and optionally notify the user."""
+ infr_type = infraction["type"]
+ icon = utils.INFRACTION_ICONS[infr_type][0]
+ reason = infraction["reason"]
+ expiry = infraction["expires_at"]
+
+ if expiry:
+ expiry = time.format_infraction(expiry)
+
+ # Default values for the confirmation message and mod log.
+ confirm_msg = f":ok_hand: applied"
+
+ # Specifying an expiry for a note or warning makes no sense.
+ if infr_type in ("note", "warning"):
+ expiry_msg = ""
+ else:
+ expiry_msg = f" until {expiry}" if expiry else " permanently"
+
+ dm_result = ""
+ dm_log_text = ""
+ expiry_log_text = f"Expires: {expiry}" if expiry else ""
+ log_title = "applied"
+ log_content = None
+
+ # DM the user about the infraction if it's not a shadow/hidden infraction.
+ if not infraction["hidden"]:
+ # Sometimes user is a discord.Object; make it a proper user.
+ await self.bot.fetch_user(user.id)
+
+ # Accordingly display whether the user was successfully notified via DM.
+ if await utils.notify_infraction(user, infr_type, expiry, reason, icon):
+ dm_result = ":incoming_envelope: "
+ dm_log_text = "\nDM: Sent"
+ else:
+ dm_log_text = "\nDM: **Failed**"
+ log_content = ctx.author.mention
+
+ if infraction["actor"] == self.bot.user.id:
+ end_msg = f" (reason: {infraction['reason']})"
+ elif ctx.channel.id not in STAFF_CHANNELS:
+ end_msg = ''
+ else:
+ infractions = await self.bot.api_client.get(
+ "bot/infractions",
+ params={"user__id": str(user.id)}
+ )
+ total = len(infractions)
+ end_msg = f" ({total} infraction{ngettext('', 's', total)} total)"
+
+ # Execute the necessary actions to apply the infraction on Discord.
+ if action_coro:
+ try:
+ await action_coro
+ if expiry:
+ # Schedule the expiration of the infraction.
+ self.schedule_task(ctx.bot.loop, infraction["id"], infraction)
+ except discord.Forbidden:
+ # Accordingly display that applying the infraction failed.
+ confirm_msg = f":x: failed to apply"
+ expiry_msg = ""
+ log_content = ctx.author.mention
+ log_title = "failed to apply"
+
+ # Send a confirmation message to the invoking context.
+ await ctx.send(
+ f"{dm_result}{confirm_msg} **{infr_type}** to {user.mention}{expiry_msg}{end_msg}."
+ )
+
+ # Send a log message to the mod log.
+ await self.mod_log.send_log_message(
+ icon_url=icon,
+ colour=Colours.soft_red,
+ title=f"Infraction {log_title}: {infr_type}",
+ thumbnail=user.avatar_url_as(static_format="png"),
+ text=textwrap.dedent(f"""
+ Member: {user.mention} (`{user.id}`)
+ Actor: {ctx.message.author}{dm_log_text}
+ Reason: {reason}
+ {expiry_log_text}
+ """),
+ content=log_content,
+ footer=f"ID {infraction['id']}"
+ )
+
+ async def pardon_infraction(self, ctx: Context, infr_type: str, user: MemberObject) -> None:
+ """Prematurely end an infraction for a user and log the action in the mod log."""
+ # Check the current active infraction
+ response = await self.bot.api_client.get(
+ 'bot/infractions',
+ params={
+ 'active': 'true',
+ 'type': infr_type,
+ 'user__id': user.id
+ }
+ )
+
+ if not response:
+ await ctx.send(f":x: There's no active {infr_type} infraction for user {user.mention}.")
+ return
+
+ # Deactivate the infraction and cancel its scheduled expiration task.
+ log_text = await self.deactivate_infraction(response[0], send_log=False)
+
+ log_text["Member"] = f"{user.mention}(`{user.id}`)"
+ log_text["Actor"] = str(ctx.message.author)
+ log_content = None
+ footer = f"ID: {response[0]['id']}"
+
+ # If multiple active infractions were found, mark them as inactive in the database
+ # and cancel their expiration tasks.
+ if len(response) > 1:
+ log.warning(f"Found more than one active {infr_type} infraction for user {user.id}")
+
+ footer = f"Infraction IDs: {', '.join(str(infr['id']) for infr in response)}"
+
+ log_note = f"Found multiple **active** {infr_type} infractions in the database."
+ if "Note" in log_text:
+ log_text["Note"] = f" {log_note}"
+ else:
+ log_text["Note"] = log_note
+
+ # deactivate_infraction() is not called again because:
+ # 1. Discord cannot store multiple active bans or assign multiples of the same role
+ # 2. It would send a pardon DM for each active infraction, which is redundant
+ for infraction in response[1:]:
+ _id = infraction['id']
+ try:
+ # Mark infraction as inactive in the database.
+ await self.bot.api_client.patch(
+ f"bot/infractions/{_id}",
+ json={"active": False}
+ )
+ except ResponseCodeError:
+ log.exception(f"Failed to deactivate infraction #{_id} ({infr_type})")
+ # This is simpler and cleaner than trying to concatenate all the errors.
+ log_text["Failure"] = "See bot's logs for details."
+
+ # Cancel pending expiration task.
+ if infraction["expires_at"] is not None:
+ self.cancel_task(infraction["id"])
+
+ # Accordingly display whether the user was successfully notified via DM.
+ dm_emoji = ""
+ if log_text.get("DM") == "Sent":
+ dm_emoji = ":incoming_envelope: "
+ elif "DM" in log_text:
+ # Mention the actor because the DM failed to send.
+ log_content = ctx.author.mention
+
+ # Accordingly display whether the pardon failed.
+ if "Failure" in log_text:
+ confirm_msg = ":x: failed to pardon"
+ log_title = "pardon failed"
+ log_content = ctx.author.mention
+ else:
+ confirm_msg = f":ok_hand: pardoned"
+ log_title = "pardoned"
+
+ # Send a confirmation message to the invoking context.
+ await ctx.send(
+ f"{dm_emoji}{confirm_msg} infraction **{infr_type}** for {user.mention}. "
+ f"{log_text.get('Failure', '')}"
+ )
+
+ # Send a log message to the mod log.
+ await self.mod_log.send_log_message(
+ icon_url=utils.INFRACTION_ICONS[infr_type][1],
+ colour=Colours.soft_green,
+ title=f"Infraction {log_title}: {infr_type}",
+ thumbnail=user.avatar_url_as(static_format="png"),
+ text="\n".join(f"{k}: {v}" for k, v in log_text.items()),
+ footer=footer,
+ content=log_content,
+ )
+
+ # endregion
+
+ # This cannot be static (must have a __func__ attribute).
+ def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators to invoke the commands in this cog."""
+ return with_role_check(ctx, *constants.MODERATION_ROLES)
+
+ # This cannot be static (must have a __func__ attribute).
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Send a notification to the invoking context on a Union failure."""
+ if isinstance(error, commands.BadUnionArgument):
+ if discord.User in error.converters:
+ await ctx.send(str(error.errors[0]))
+ error.handled = True
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
new file mode 100644
index 000000000..44a508436
--- /dev/null
+++ b/bot/cogs/moderation/management.py
@@ -0,0 +1,272 @@
+import asyncio
+import logging
+import textwrap
+import typing as t
+
+import discord
+from discord.ext import commands
+from discord.ext.commands import Context
+
+from bot import constants
+from bot.converters import InfractionSearchQuery
+from bot.pagination import LinePaginator
+from bot.utils import time
+from bot.utils.checks import in_channel_check, with_role_check
+from . import utils
+from .infractions import Infractions
+from .modlog import ModLog
+
+log = logging.getLogger(__name__)
+
+UserConverter = t.Union[discord.User, utils.proxy_user]
+
+
+def permanent_duration(expires_at: str) -> str:
+ """Only allow an expiration to be 'permanent' if it is a string."""
+ expires_at = expires_at.lower()
+ if expires_at != "permanent":
+ raise commands.BadArgument
+ else:
+ return expires_at
+
+
+class ModManagement(commands.Cog):
+ """Management of infractions."""
+
+ category = "Moderation"
+
+ def __init__(self, bot: commands.Bot):
+ self.bot = bot
+
+ @property
+ def mod_log(self) -> ModLog:
+ """Get currently loaded ModLog cog instance."""
+ return self.bot.get_cog("ModLog")
+
+ @property
+ def infractions_cog(self) -> Infractions:
+ """Get currently loaded Infractions cog instance."""
+ return self.bot.get_cog("Infractions")
+
+ # region: Edit infraction commands
+
+ @commands.group(name='infraction', aliases=('infr', 'infractions', 'inf'), invoke_without_command=True)
+ async def infraction_group(self, ctx: Context) -> None:
+ """Infraction manipulation commands."""
+ await ctx.invoke(self.bot.get_command("help"), "infraction")
+
+ @infraction_group.command(name='edit')
+ async def infraction_edit(
+ self,
+ ctx: Context,
+ infraction_id: int,
+ duration: t.Union[utils.Expiry, permanent_duration, None],
+ *,
+ reason: str = None
+ ) -> None:
+ """
+ Edit the duration and/or the reason of an infraction.
+
+ Durations are relative to the time of updating and should be appended with a unit of time.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Use "permanent" to mark the infraction as permanent. Alternatively, an ISO 8601 timestamp
+ can be provided for the duration.
+ """
+ if duration is None and reason is None:
+ # Unlike UserInputError, the error handler will show a specified message for BadArgument
+ raise commands.BadArgument("Neither a new expiry nor a new reason was specified.")
+
+ # Retrieve the previous infraction for its information.
+ old_infraction = await self.bot.api_client.get(f'bot/infractions/{infraction_id}')
+
+ request_data = {}
+ confirm_messages = []
+ log_text = ""
+
+ if duration == "permanent":
+ request_data['expires_at'] = None
+ confirm_messages.append("marked as permanent")
+ elif duration is not None:
+ request_data['expires_at'] = duration.isoformat()
+ expiry = duration.strftime(time.INFRACTION_FORMAT)
+ confirm_messages.append(f"set to expire on {expiry}")
+ else:
+ confirm_messages.append("expiry unchanged")
+
+ if reason:
+ request_data['reason'] = reason
+ confirm_messages.append("set a new reason")
+ log_text += f"""
+ Previous reason: {old_infraction['reason']}
+ New reason: {reason}
+ """.rstrip()
+ else:
+ confirm_messages.append("reason unchanged")
+
+ # Update the infraction
+ new_infraction = await self.bot.api_client.patch(
+ f'bot/infractions/{infraction_id}',
+ json=request_data,
+ )
+
+ # Re-schedule infraction if the expiration has been updated
+ if 'expires_at' in request_data:
+ self.infractions_cog.cancel_task(new_infraction['id'])
+ loop = asyncio.get_event_loop()
+ self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction)
+
+ log_text += f"""
+ Previous expiry: {old_infraction['expires_at'] or "Permanent"}
+ New expiry: {new_infraction['expires_at'] or "Permanent"}
+ """.rstrip()
+
+ await ctx.send(f":ok_hand: Updated infraction: {' & '.join(confirm_messages)}")
+
+ # Get information about the infraction's user
+ user_id = new_infraction['user']
+ user = ctx.guild.get_member(user_id)
+
+ if user:
+ user_text = f"{user.mention} (`{user.id}`)"
+ thumbnail = user.avatar_url_as(static_format="png")
+ else:
+ user_text = f"`{user_id}`"
+ thumbnail = None
+
+ # The infraction's actor
+ actor_id = new_infraction['actor']
+ actor = ctx.guild.get_member(actor_id) or f"`{actor_id}`"
+
+ await self.mod_log.send_log_message(
+ icon_url=constants.Icons.pencil,
+ colour=discord.Colour.blurple(),
+ title="Infraction edited",
+ thumbnail=thumbnail,
+ text=textwrap.dedent(f"""
+ Member: {user_text}
+ Actor: {actor}
+ Edited by: {ctx.message.author}{log_text}
+ """)
+ )
+
+ # endregion
+ # region: Search infractions
+
+ @infraction_group.group(name="search", invoke_without_command=True)
+ async def infraction_search_group(self, ctx: Context, query: InfractionSearchQuery) -> None:
+ """Searches for infractions in the database."""
+ if isinstance(query, discord.User):
+ await ctx.invoke(self.search_user, query)
+ else:
+ await ctx.invoke(self.search_reason, query)
+
+ @infraction_search_group.command(name="user", aliases=("member", "id"))
+ async def search_user(self, ctx: Context, user: UserConverter) -> None:
+ """Search for infractions by member."""
+ infraction_list = await self.bot.api_client.get(
+ 'bot/infractions',
+ params={'user__id': str(user.id)}
+ )
+ embed = discord.Embed(
+ title=f"Infractions for {user} ({len(infraction_list)} total)",
+ colour=discord.Colour.orange()
+ )
+ await self.send_infraction_list(ctx, embed, infraction_list)
+
+ @infraction_search_group.command(name="reason", aliases=("match", "regex", "re"))
+ async def search_reason(self, ctx: Context, reason: str) -> None:
+ """Search for infractions by their reason. Use Re2 for matching."""
+ infraction_list = await self.bot.api_client.get(
+ 'bot/infractions',
+ params={'search': reason}
+ )
+ embed = discord.Embed(
+ title=f"Infractions matching `{reason}` ({len(infraction_list)} total)",
+ colour=discord.Colour.orange()
+ )
+ await self.send_infraction_list(ctx, embed, infraction_list)
+
+ # endregion
+ # region: Utility functions
+
+ async def send_infraction_list(
+ self,
+ ctx: Context,
+ embed: discord.Embed,
+ infractions: t.Iterable[utils.Infraction]
+ ) -> None:
+ """Send a paginated embed of infractions for the specified user."""
+ if not infractions:
+ await ctx.send(f":warning: No infractions could be found for that query.")
+ return
+
+ lines = tuple(
+ self.infraction_to_string(infraction)
+ for infraction in infractions
+ )
+
+ await LinePaginator.paginate(
+ lines,
+ ctx=ctx,
+ embed=embed,
+ empty=True,
+ max_lines=3,
+ max_size=1000
+ )
+
+ def infraction_to_string(self, infraction: utils.Infraction) -> str:
+ """Convert the infraction object to a string representation."""
+ actor_id = infraction["actor"]
+ guild = self.bot.get_guild(constants.Guild.id)
+ actor = guild.get_member(actor_id)
+ active = infraction["active"]
+ user_id = infraction["user"]
+ hidden = infraction["hidden"]
+ created = time.format_infraction(infraction["inserted_at"])
+ if infraction["expires_at"] is None:
+ expires = "*Permanent*"
+ else:
+ expires = time.format_infraction(infraction["expires_at"])
+
+ lines = textwrap.dedent(f"""
+ {"**===============**" if active else "==============="}
+ Status: {"__**Active**__" if active else "Inactive"}
+ User: {self.bot.get_user(user_id)} (`{user_id}`)
+ Type: **{infraction["type"]}**
+ Shadow: {hidden}
+ Reason: {infraction["reason"] or "*None*"}
+ Created: {created}
+ Expires: {expires}
+ Actor: {actor.mention if actor else actor_id}
+ ID: `{infraction["id"]}`
+ {"**===============**" if active else "==============="}
+ """)
+
+ return lines.strip()
+
+ # endregion
+
+ # This cannot be static (must have a __func__ attribute).
+ def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators from moderator channels to invoke the commands in this cog."""
+ checks = [
+ with_role_check(ctx, *constants.MODERATION_ROLES),
+ in_channel_check(ctx, *constants.MODERATION_CHANNELS)
+ ]
+ return all(checks)
+
+ # This cannot be static (must have a __func__ attribute).
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
+ """Send a notification to the invoking context on a Union failure."""
+ if isinstance(error, commands.BadUnionArgument):
+ if discord.User in error.converters:
+ await ctx.send(str(error.errors[0]))
+ error.handled = True
diff --git a/bot/cogs/modlog.py b/bot/cogs/moderation/modlog.py
index 68424d268..88f2b6c67 100644
--- a/bot/cogs/modlog.py
+++ b/bot/cogs/moderation/modlog.py
@@ -1,30 +1,26 @@
import asyncio
import logging
+import typing as t
from datetime import datetime
-from typing import List, Optional, Union
+import discord
from dateutil.relativedelta import relativedelta
from deepdiff import DeepDiff
-from discord import (
- CategoryChannel, Colour, Embed, File, Guild,
- Member, Message, NotFound, RawMessageDeleteEvent,
- RawMessageUpdateEvent, Role, TextChannel, User, VoiceChannel
-)
+from discord import Colour
from discord.abc import GuildChannel
from discord.ext.commands import Bot, Cog, Context
-from bot.constants import (
- Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs
-)
+from bot.constants import Channels, Colours, Emojis, Event, Guild as GuildConstant, Icons, URLs
from bot.utils.time import humanize_delta
+from .utils import UserTypes
log = logging.getLogger(__name__)
-GUILD_CHANNEL = Union[CategoryChannel, TextChannel, VoiceChannel]
+GUILD_CHANNEL = t.Union[discord.CategoryChannel, discord.TextChannel, discord.VoiceChannel]
CHANNEL_CHANGES_UNSUPPORTED = ("permissions",)
CHANNEL_CHANGES_SUPPRESSED = ("_overwrites", "position")
-MEMBER_CHANGES_SUPPRESSED = ("status", "activities", "_client_status")
+MEMBER_CHANGES_SUPPRESSED = ("status", "activities", "_client_status", "nick")
ROLE_CHANGES_UNSUPPORTED = ("colour", "permissions")
@@ -38,7 +34,7 @@ class ModLog(Cog, name="ModLog"):
self._cached_deletes = []
self._cached_edits = []
- async def upload_log(self, messages: List[Message], actor_id: int) -> str:
+ async def upload_log(self, messages: t.List[discord.Message], actor_id: int) -> str:
"""
Uploads the log data to the database via an API endpoint for uploading logs.
@@ -73,23 +69,23 @@ class ModLog(Cog, name="ModLog"):
self._ignored[event].append(item)
async def send_log_message(
- self,
- icon_url: Optional[str],
- colour: Colour,
- title: Optional[str],
- text: str,
- thumbnail: Optional[str] = None,
- channel_id: int = Channels.modlog,
- ping_everyone: bool = False,
- files: Optional[List[File]] = None,
- content: Optional[str] = None,
- additional_embeds: Optional[List[Embed]] = None,
- additional_embeds_msg: Optional[str] = None,
- timestamp_override: Optional[datetime] = None,
- footer: Optional[str] = None,
+ self,
+ icon_url: t.Optional[str],
+ colour: t.Union[discord.Colour, int],
+ title: t.Optional[str],
+ text: str,
+ thumbnail: t.Optional[t.Union[str, discord.Asset]] = None,
+ channel_id: int = Channels.modlog,
+ ping_everyone: bool = False,
+ files: t.Optional[t.List[discord.File]] = None,
+ content: t.Optional[str] = None,
+ additional_embeds: t.Optional[t.List[discord.Embed]] = None,
+ additional_embeds_msg: t.Optional[str] = None,
+ timestamp_override: t.Optional[datetime] = None,
+ footer: t.Optional[str] = None,
) -> Context:
"""Generate log embed and send to logging channel."""
- embed = Embed(description=text)
+ embed = discord.Embed(description=text)
if title and icon_url:
embed.set_author(name=title, icon_url=icon_url)
@@ -126,10 +122,10 @@ class ModLog(Cog, name="ModLog"):
if channel.guild.id != GuildConstant.id:
return
- if isinstance(channel, CategoryChannel):
+ if isinstance(channel, discord.CategoryChannel):
title = "Category created"
message = f"{channel.name} (`{channel.id}`)"
- elif isinstance(channel, VoiceChannel):
+ elif isinstance(channel, discord.VoiceChannel):
title = "Voice channel created"
if channel.category:
@@ -144,7 +140,7 @@ class ModLog(Cog, name="ModLog"):
else:
message = f"{channel.name} (`{channel.id}`)"
- await self.send_log_message(Icons.hash_green, Colour(Colours.soft_green), title, message)
+ await self.send_log_message(Icons.hash_green, Colours.soft_green, title, message)
@Cog.listener()
async def on_guild_channel_delete(self, channel: GUILD_CHANNEL) -> None:
@@ -152,20 +148,20 @@ class ModLog(Cog, name="ModLog"):
if channel.guild.id != GuildConstant.id:
return
- if isinstance(channel, CategoryChannel):
+ if isinstance(channel, discord.CategoryChannel):
title = "Category deleted"
- elif isinstance(channel, VoiceChannel):
+ elif isinstance(channel, discord.VoiceChannel):
title = "Voice channel deleted"
else:
title = "Text channel deleted"
- if channel.category and not isinstance(channel, CategoryChannel):
+ if channel.category and not isinstance(channel, discord.CategoryChannel):
message = f"{channel.category}/{channel.name} (`{channel.id}`)"
else:
message = f"{channel.name} (`{channel.id}`)"
await self.send_log_message(
- Icons.hash_red, Colour(Colours.soft_red),
+ Icons.hash_red, Colours.soft_red,
title, message
)
@@ -230,29 +226,29 @@ class ModLog(Cog, name="ModLog"):
)
@Cog.listener()
- async def on_guild_role_create(self, role: Role) -> None:
+ async def on_guild_role_create(self, role: discord.Role) -> None:
"""Log role create event to mod log."""
if role.guild.id != GuildConstant.id:
return
await self.send_log_message(
- Icons.crown_green, Colour(Colours.soft_green),
+ Icons.crown_green, Colours.soft_green,
"Role created", f"`{role.id}`"
)
@Cog.listener()
- async def on_guild_role_delete(self, role: Role) -> None:
+ async def on_guild_role_delete(self, role: discord.Role) -> None:
"""Log role delete event to mod log."""
if role.guild.id != GuildConstant.id:
return
await self.send_log_message(
- Icons.crown_red, Colour(Colours.soft_red),
+ Icons.crown_red, Colours.soft_red,
"Role removed", f"{role.name} (`{role.id}`)"
)
@Cog.listener()
- async def on_guild_role_update(self, before: Role, after: Role) -> None:
+ async def on_guild_role_update(self, before: discord.Role, after: discord.Role) -> None:
"""Log role update event to mod log."""
if before.guild.id != GuildConstant.id:
return
@@ -305,7 +301,7 @@ class ModLog(Cog, name="ModLog"):
)
@Cog.listener()
- async def on_guild_update(self, before: Guild, after: Guild) -> None:
+ async def on_guild_update(self, before: discord.Guild, after: discord.Guild) -> None:
"""Log guild update event to mod log."""
if before.id != GuildConstant.id:
return
@@ -356,8 +352,8 @@ class ModLog(Cog, name="ModLog"):
)
@Cog.listener()
- async def on_member_ban(self, guild: Guild, member: Union[Member, User]) -> None:
- """Log ban event to mod log."""
+ async def on_member_ban(self, guild: discord.Guild, member: UserTypes) -> None:
+ """Log ban event to user log."""
if guild.id != GuildConstant.id:
return
@@ -366,19 +362,19 @@ class ModLog(Cog, name="ModLog"):
return
await self.send_log_message(
- Icons.user_ban, Colour(Colours.soft_red),
- "User banned", f"{member.name}#{member.discriminator} (`{member.id}`)",
+ Icons.user_ban, Colours.soft_red,
+ "User banned", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
- channel_id=Channels.modlog
+ channel_id=Channels.userlog
)
@Cog.listener()
- async def on_member_join(self, member: Member) -> None:
+ async def on_member_join(self, member: discord.Member) -> None:
"""Log member join event to user log."""
if member.guild.id != GuildConstant.id:
return
- message = f"{member.name}#{member.discriminator} (`{member.id}`)"
+ message = f"{member} (`{member.id}`)"
now = datetime.utcnow()
difference = abs(relativedelta(now, member.created_at))
@@ -388,14 +384,14 @@ class ModLog(Cog, name="ModLog"):
message = f"{Emojis.new} {message}"
await self.send_log_message(
- Icons.sign_in, Colour(Colours.soft_green),
+ Icons.sign_in, Colours.soft_green,
"User joined", message,
thumbnail=member.avatar_url_as(static_format="png"),
channel_id=Channels.userlog
)
@Cog.listener()
- async def on_member_remove(self, member: Member) -> None:
+ async def on_member_remove(self, member: discord.Member) -> None:
"""Log member leave event to user log."""
if member.guild.id != GuildConstant.id:
return
@@ -405,14 +401,14 @@ class ModLog(Cog, name="ModLog"):
return
await self.send_log_message(
- Icons.sign_out, Colour(Colours.soft_red),
- "User left", f"{member.name}#{member.discriminator} (`{member.id}`)",
+ Icons.sign_out, Colours.soft_red,
+ "User left", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
channel_id=Channels.userlog
)
@Cog.listener()
- async def on_member_unban(self, guild: Guild, member: User) -> None:
+ async def on_member_unban(self, guild: discord.Guild, member: discord.User) -> None:
"""Log member unban event to mod log."""
if guild.id != GuildConstant.id:
return
@@ -423,13 +419,13 @@ class ModLog(Cog, name="ModLog"):
await self.send_log_message(
Icons.user_unban, Colour.blurple(),
- "User unbanned", f"{member.name}#{member.discriminator} (`{member.id}`)",
+ "User unbanned", f"{member} (`{member.id}`)",
thumbnail=member.avatar_url_as(static_format="png"),
channel_id=Channels.modlog
)
@Cog.listener()
- async def on_member_update(self, before: Member, after: Member) -> None:
+ async def on_member_update(self, before: discord.Member, after: discord.Member) -> None:
"""Log member update event to user log."""
if before.guild.id != GuildConstant.id:
return
@@ -502,6 +498,11 @@ class ModLog(Cog, name="ModLog"):
f"**Discriminator:** `{before.discriminator}` **->** `{after.discriminator}`"
)
+ if before.display_name != after.display_name:
+ changes.append(
+ f"**Display name:** `{before.display_name}` **->** `{after.display_name}`"
+ )
+
if not changes:
return
@@ -510,7 +511,7 @@ class ModLog(Cog, name="ModLog"):
for item in sorted(changes):
message += f"{Emojis.bullet} {item}\n"
- message = f"**{after.name}#{after.discriminator}** (`{after.id}`)\n{message}"
+ message = f"**{after}** (`{after.id}`)\n{message}"
await self.send_log_message(
Icons.user_update, Colour.blurple(),
@@ -520,7 +521,7 @@ class ModLog(Cog, name="ModLog"):
)
@Cog.listener()
- async def on_message_delete(self, message: Message) -> None:
+ async def on_message_delete(self, message: discord.Message) -> None:
"""Log message delete event to message change log."""
channel = message.channel
author = message.author
@@ -539,14 +540,14 @@ class ModLog(Cog, name="ModLog"):
if channel.category:
response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{message.id}`\n"
"\n"
)
else:
response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** #{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{message.id}`\n"
"\n"
@@ -576,7 +577,7 @@ class ModLog(Cog, name="ModLog"):
)
@Cog.listener()
- async def on_raw_message_delete(self, event: RawMessageDeleteEvent) -> None:
+ async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None:
"""Log raw message delete event to message change log."""
if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.ignored:
return
@@ -610,14 +611,14 @@ class ModLog(Cog, name="ModLog"):
)
await self.send_log_message(
- Icons.message_delete, Colour(Colours.soft_red),
+ Icons.message_delete, Colours.soft_red,
"Message deleted",
response,
channel_id=Channels.message_log
)
@Cog.listener()
- async def on_message_edit(self, before: Message, after: Message) -> None:
+ async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
"""Log message edit event to message change log."""
if (
not before.guild
@@ -637,7 +638,7 @@ class ModLog(Cog, name="ModLog"):
if channel.category:
before_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{before.id}`\n"
"\n"
@@ -645,7 +646,7 @@ class ModLog(Cog, name="ModLog"):
)
after_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{before.id}`\n"
"\n"
@@ -653,7 +654,7 @@ class ModLog(Cog, name="ModLog"):
)
else:
before_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** #{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{before.id}`\n"
"\n"
@@ -661,7 +662,7 @@ class ModLog(Cog, name="ModLog"):
)
after_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** #{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{before.id}`\n"
"\n"
@@ -692,12 +693,12 @@ class ModLog(Cog, name="ModLog"):
)
@Cog.listener()
- async def on_raw_message_edit(self, event: RawMessageUpdateEvent) -> None:
+ async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent) -> None:
"""Log raw message edit event to message change log."""
try:
channel = self.bot.get_channel(int(event.data["channel_id"]))
message = await channel.fetch_message(event.message_id)
- except NotFound: # Was deleted before we got the event
+ except discord.NotFound: # Was deleted before we got the event
return
if (
@@ -720,7 +721,7 @@ class ModLog(Cog, name="ModLog"):
if channel.category:
before_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{message.id}`\n"
"\n"
@@ -728,7 +729,7 @@ class ModLog(Cog, name="ModLog"):
)
after_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** {channel.category}/#{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{message.id}`\n"
"\n"
@@ -736,7 +737,7 @@ class ModLog(Cog, name="ModLog"):
)
else:
before_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** #{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{message.id}`\n"
"\n"
@@ -744,7 +745,7 @@ class ModLog(Cog, name="ModLog"):
)
after_response = (
- f"**Author:** {author.name}#{author.discriminator} (`{author.id}`)\n"
+ f"**Author:** {author} (`{author.id}`)\n"
f"**Channel:** #{channel.name} (`{channel.id}`)\n"
f"**Message ID:** `{message.id}`\n"
"\n"
@@ -760,9 +761,3 @@ class ModLog(Cog, name="ModLog"):
Icons.message_edit, Colour.blurple(), "Message edited (After)",
after_response, channel_id=Channels.message_log
)
-
-
-def setup(bot: Bot) -> None:
- """Mod log cog load."""
- bot.add_cog(ModLog(bot))
- log.info("Cog loaded: ModLog")
diff --git a/bot/cogs/superstarify/__init__.py b/bot/cogs/moderation/superstarify.py
index 87021eded..82f8621fc 100644
--- a/bot/cogs/superstarify/__init__.py
+++ b/bot/cogs/moderation/superstarify.py
@@ -1,21 +1,23 @@
+import json
import logging
import random
+from pathlib import Path
from discord import Colour, Embed, Member
from discord.errors import Forbidden
from discord.ext.commands import Bot, Cog, Context, command
-from bot.cogs.moderation import Moderation
-from bot.cogs.modlog import ModLog
-from bot.cogs.superstarify.stars import get_nick
-from bot.constants import Icons, MODERATION_ROLES, POSITIVE_REPLIES
-from bot.converters import Duration
-from bot.decorators import with_role
-from bot.utils.moderation import post_infraction
+from bot import constants
+from bot.utils.checks import with_role_check
from bot.utils.time import format_infraction
+from . import utils
+from .modlog import ModLog
log = logging.getLogger(__name__)
-NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#wiki-toc-nickname-policy"
+NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy"
+
+with Path("bot/resources/stars.json").open(encoding="utf-8") as stars_file:
+ STAR_NAMES = json.load(stars_file)
class Superstarify(Cog):
@@ -25,11 +27,6 @@ class Superstarify(Cog):
self.bot = bot
@property
- def moderation(self) -> Moderation:
- """Get currently loaded Moderation cog instance."""
- return self.bot.get_cog("Moderation")
-
- @property
def modlog(self) -> ModLog:
"""Get currently loaded ModLog cog instance."""
return self.bot.get_cog("ModLog")
@@ -62,7 +59,7 @@ class Superstarify(Cog):
if active_superstarifies:
[infraction] = active_superstarifies
- forced_nick = get_nick(infraction['id'], before.id)
+ forced_nick = self.get_nick(infraction['id'], before.id)
if after.display_name == forced_nick:
return # Nick change was triggered by this event. Ignore.
@@ -108,7 +105,7 @@ class Superstarify(Cog):
if active_superstarifies:
[infraction] = active_superstarifies
- forced_nick = get_nick(infraction['id'], member.id)
+ forced_nick = self.get_nick(infraction['id'], member.id)
await member.edit(nick=forced_nick)
end_timestamp_human = format_infraction(infraction['expires_at'])
@@ -132,13 +129,13 @@ class Superstarify(Cog):
# Log to the mod_log channel
log.trace("Logging to the #mod-log channel. This could fail because of channel permissions.")
mod_log_message = (
- f"**{member.name}#{member.discriminator}** (`{member.id}`)\n\n"
+ f"**{member}** (`{member.id}`)\n\n"
f"Superstarified member potentially tried to escape the prison.\n"
f"Restored enforced nickname: `{forced_nick}`\n"
f"Superstardom ends: **{end_timestamp_human}**"
)
await self.modlog.send_log_message(
- icon_url=Icons.user_update,
+ icon_url=constants.Icons.user_update,
colour=Colour.gold(),
title="Superstar member rejoined server",
text=mod_log_message,
@@ -146,45 +143,39 @@ class Superstarify(Cog):
)
@command(name='superstarify', aliases=('force_nick', 'star'))
- @with_role(*MODERATION_ROLES)
- async def superstarify(
- self, ctx: Context, member: Member, expiration: Duration, reason: str = None
- ) -> None:
+ async def superstarify(self, ctx: Context, member: Member, duration: utils.Expiry, reason: str = None) -> None:
"""
Force a random superstar name (like Taylor Swift) to be the user's nickname for a specified duration.
- An optional reason can be provided.
+ A unit of time should be appended to the duration.
+ Units (∗case-sensitive):
+ \u2003`y` - years
+ \u2003`m` - months∗
+ \u2003`w` - weeks
+ \u2003`d` - days
+ \u2003`h` - hours
+ \u2003`M` - minutes∗
+ \u2003`s` - seconds
+
+ Alternatively, an ISO 8601 timestamp can be provided for the duration.
- If no reason is given, the original name will be shown in a generated reason.
+ An optional reason can be provided. If no reason is given, the original name will be shown
+ in a generated reason.
"""
- active_superstarifies = await self.bot.api_client.get(
- 'bot/infractions',
- params={
- 'active': 'true',
- 'type': 'superstar',
- 'user__id': str(member.id)
- }
- )
- if active_superstarifies:
- await ctx.send(
- ":x: According to my records, this user is already superstarified. "
- f"See infraction **#{active_superstarifies[0]['id']}**."
- )
+ if await utils.has_active_infraction(ctx, member, "superstar"):
return
- infraction = await post_infraction(
- ctx, member,
- type='superstar', reason=reason or ('old nick: ' + member.display_name),
- expires_at=expiration
- )
- forced_nick = get_nick(infraction['id'], member.id)
+ reason = reason or ('old nick: ' + member.display_name)
+ infraction = await utils.post_infraction(ctx, member, 'superstar', reason, expires_at=duration)
+ forced_nick = self.get_nick(infraction['id'], member.id)
+ expiry_str = format_infraction(infraction["expires_at"])
embed = Embed()
embed.title = "Congratulations!"
embed.description = (
f"Your previous nickname, **{member.display_name}**, was so bad that we have decided to change it. "
f"Your new nickname will be **{forced_nick}**.\n\n"
- f"You will be unable to change your nickname until \n**{expiration}**.\n\n"
+ f"You will be unable to change your nickname until \n**{expiry_str}**.\n\n"
"If you're confused by this, please read our "
f"[official nickname policy]({NICKNAME_POLICY_URL})."
)
@@ -192,24 +183,24 @@ class Superstarify(Cog):
# Log to the mod_log channel
log.trace("Logging to the #mod-log channel. This could fail because of channel permissions.")
mod_log_message = (
- f"**{member.name}#{member.discriminator}** (`{member.id}`)\n\n"
+ f"**{member}** (`{member.id}`)\n\n"
f"Superstarified by **{ctx.author.name}**\n"
f"Old nickname: `{member.display_name}`\n"
f"New nickname: `{forced_nick}`\n"
- f"Superstardom ends: **{expiration}**"
+ f"Superstardom ends: **{expiry_str}**"
)
await self.modlog.send_log_message(
- icon_url=Icons.user_update,
+ icon_url=constants.Icons.user_update,
colour=Colour.gold(),
title="Member Achieved Superstardom",
text=mod_log_message,
thumbnail=member.avatar_url_as(static_format="png")
)
- await self.moderation.notify_infraction(
+ await utils.notify_infraction(
user=member,
infr_type="Superstarify",
- expires_at=expiration,
+ expires_at=expiry_str,
reason=f"Your nickname didn't comply with our [nickname policy]({NICKNAME_POLICY_URL})."
)
@@ -219,7 +210,6 @@ class Superstarify(Cog):
await ctx.send(embed=embed)
@command(name='unsuperstarify', aliases=('release_nick', 'unstar'))
- @with_role(*MODERATION_ROLES)
async def unsuperstarify(self, ctx: Context, member: Member) -> None:
"""Remove the superstarify entry from our database, allowing the user to change their nickname."""
log.debug(f"Attempting to unsuperstarify the following user: {member.display_name}")
@@ -247,9 +237,9 @@ class Superstarify(Cog):
embed = Embed()
embed.description = "User has been released from superstar-prison."
- embed.title = random.choice(POSITIVE_REPLIES)
+ embed.title = random.choice(constants.POSITIVE_REPLIES)
- await self.moderation.notify_pardon(
+ await utils.notify_pardon(
user=member,
title="You are no longer superstarified.",
content="You may now change your nickname on the server."
@@ -257,8 +247,13 @@ class Superstarify(Cog):
log.trace(f"{member.display_name} was successfully released from superstar-prison.")
await ctx.send(embed=embed)
+ @staticmethod
+ def get_nick(infraction_id: int, member_id: int) -> str:
+ """Randomly select a nickname from the Superstarify nickname list."""
+ rng = random.Random(str(infraction_id) + str(member_id))
+ return rng.choice(STAR_NAMES)
-def setup(bot: Bot) -> None:
- """Superstarify cog load."""
- bot.add_cog(Superstarify(bot))
- log.info("Cog loaded: Superstarify")
+ # This cannot be static (must have a __func__ attribute).
+ def cog_check(self, ctx: Context) -> bool:
+ """Only allow moderators to invoke the commands in this cog."""
+ return with_role_check(ctx, *constants.MODERATION_ROLES)
diff --git a/bot/cogs/moderation/utils.py b/bot/cogs/moderation/utils.py
new file mode 100644
index 000000000..788a40d40
--- /dev/null
+++ b/bot/cogs/moderation/utils.py
@@ -0,0 +1,172 @@
+import logging
+import textwrap
+import typing as t
+from datetime import datetime
+
+import discord
+from discord.ext import commands
+from discord.ext.commands import Context
+
+from bot.api import ResponseCodeError
+from bot.constants import Colours, Icons
+from bot.converters import Duration, ISODateTime
+
+log = logging.getLogger(__name__)
+
+# apply icon, pardon icon
+INFRACTION_ICONS = {
+ "mute": (Icons.user_mute, Icons.user_unmute),
+ "kick": (Icons.sign_out, None),
+ "ban": (Icons.user_ban, Icons.user_unban),
+ "warning": (Icons.user_warn, None),
+ "note": (Icons.user_warn, None),
+}
+RULES_URL = "https://pythondiscord.com/pages/rules"
+APPEALABLE_INFRACTIONS = ("ban", "mute")
+
+UserTypes = t.Union[discord.Member, discord.User]
+MemberObject = t.Union[UserTypes, discord.Object]
+Infraction = t.Dict[str, t.Union[str, int, bool]]
+Expiry = t.Union[Duration, ISODateTime]
+
+
+def proxy_user(user_id: str) -> discord.Object:
+ """
+ Create a proxy user object from the given id.
+
+ Used when a Member or User object cannot be resolved.
+ """
+ try:
+ user_id = int(user_id)
+ except ValueError:
+ raise commands.BadArgument
+
+ user = discord.Object(user_id)
+ user.mention = user.id
+ user.avatar_url_as = lambda static_format: None
+
+ return user
+
+
+async def post_infraction(
+ ctx: Context,
+ user: MemberObject,
+ infr_type: str,
+ reason: str,
+ expires_at: datetime = None,
+ hidden: bool = False,
+ active: bool = True,
+) -> t.Optional[dict]:
+ """Posts an infraction to the API."""
+ payload = {
+ "actor": ctx.message.author.id,
+ "hidden": hidden,
+ "reason": reason,
+ "type": infr_type,
+ "user": user.id,
+ "active": active
+ }
+ if expires_at:
+ payload['expires_at'] = expires_at.isoformat()
+
+ try:
+ response = await ctx.bot.api_client.post('bot/infractions', json=payload)
+ except ResponseCodeError as exp:
+ if exp.status == 400 and 'user' in exp.response_json:
+ log.info(
+ f"{ctx.author} tried to add a {infr_type} infraction to `{user.id}`, "
+ "but that user id was not found in the database."
+ )
+ await ctx.send(
+ f":x: Cannot add infraction, the specified user is not known to the database."
+ )
+ return
+ else:
+ log.exception("An unexpected ResponseCodeError occurred while adding an infraction:")
+ await ctx.send(":x: There was an error adding the infraction.")
+ return
+
+ return response
+
+
+async def has_active_infraction(ctx: Context, user: MemberObject, infr_type: str) -> bool:
+ """Checks if a user already has an active infraction of the given type."""
+ active_infractions = await ctx.bot.api_client.get(
+ 'bot/infractions',
+ params={
+ 'active': 'true',
+ 'type': infr_type,
+ 'user__id': str(user.id)
+ }
+ )
+ if active_infractions:
+ await ctx.send(
+ f":x: According to my records, this user already has a {infr_type} infraction. "
+ f"See infraction **#{active_infractions[0]['id']}**."
+ )
+ return True
+ else:
+ return False
+
+
+async def notify_infraction(
+ user: UserTypes,
+ infr_type: str,
+ expires_at: t.Optional[str] = None,
+ reason: t.Optional[str] = None,
+ icon_url: str = Icons.token_removed
+) -> bool:
+ """DM a user about their new infraction and return True if the DM is successful."""
+ embed = discord.Embed(
+ description=textwrap.dedent(f"""
+ **Type:** {infr_type.capitalize()}
+ **Expires:** {expires_at or "N/A"}
+ **Reason:** {reason or "No reason provided."}
+ """),
+ colour=Colours.soft_red
+ )
+
+ embed.set_author(name="Infraction Information", icon_url=icon_url, url=RULES_URL)
+ embed.title = f"Please review our rules over at {RULES_URL}"
+ embed.url = RULES_URL
+
+ if infr_type in APPEALABLE_INFRACTIONS:
+ embed.set_footer(
+ text="To appeal this infraction, send an e-mail to [email protected]"
+ )
+
+ return await send_private_embed(user, embed)
+
+
+async def notify_pardon(
+ user: UserTypes,
+ title: str,
+ content: str,
+ icon_url: str = Icons.user_verified
+) -> bool:
+ """DM a user about their pardoned infraction and return True if the DM is successful."""
+ embed = discord.Embed(
+ description=content,
+ colour=Colours.soft_green
+ )
+
+ embed.set_author(name=title, icon_url=icon_url)
+
+ return await send_private_embed(user, embed)
+
+
+async def send_private_embed(user: UserTypes, embed: discord.Embed) -> bool:
+ """
+ A helper method for sending an embed to a user's DMs.
+
+ Returns a boolean indicator of DM success.
+ """
+ try:
+ await user.send(embed=embed)
+ return True
+ except (discord.HTTPException, discord.Forbidden, discord.NotFound):
+ log.debug(
+ f"Infraction-related information could not be sent to user {user} ({user.id}). "
+ "The user either could not be retrieved or probably disabled their DMs."
+ )
+ return False
diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py
index 16717d523..78792240f 100644
--- a/bot/cogs/off_topic_names.py
+++ b/bot/cogs/off_topic_names.py
@@ -24,6 +24,9 @@ class OffTopicName(Converter):
"""Attempt to replace any invalid characters with their approximate Unicode equivalent."""
allowed_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-"
+ # Chain multiple words to a single one
+ argument = "-".join(argument.split())
+
if not (2 <= len(argument) <= 96):
raise BadArgument("Channel name must be between 2 and 96 chars long")
@@ -75,14 +78,16 @@ class OffTopicNames(Cog):
self.bot = bot
self.updater_task = None
+ self.bot.loop.create_task(self.init_offtopic_updater())
+
def cog_unload(self) -> None:
"""Cancel any running updater tasks on cog unload."""
if self.updater_task is not None:
self.updater_task.cancel()
- @Cog.listener()
- async def on_ready(self) -> None:
+ async def init_offtopic_updater(self) -> None:
"""Start off-topic channel updating event loop if it hasn't already started."""
+ await self.bot.wait_until_ready()
if self.updater_task is None:
coro = update_names(self.bot)
self.updater_task = self.bot.loop.create_task(coro)
@@ -95,30 +100,47 @@ class OffTopicNames(Cog):
@otname_group.command(name='add', aliases=('a',))
@with_role(*MODERATION_ROLES)
- async def add_command(self, ctx: Context, *names: OffTopicName) -> None:
- """Adds a new off-topic name to the rotation."""
- # Chain multiple words to a single one
- name = "-".join(names)
+ async def add_command(self, ctx: Context, *, name: OffTopicName) -> None:
+ """
+ Adds a new off-topic name to the rotation.
- await self.bot.api_client.post(f'bot/off-topic-channel-names', params={'name': name})
- log.info(
- f"{ctx.author.name}#{ctx.author.discriminator}"
- f" added the off-topic channel name '{name}"
- )
+ The name is not added if it is too similar to an existing name.
+ """
+ existing_names = await self.bot.api_client.get('bot/off-topic-channel-names')
+ close_match = difflib.get_close_matches(name, existing_names, n=1, cutoff=0.8)
+
+ if close_match:
+ match = close_match[0]
+ log.info(
+ f"{ctx.author} tried to add channel name '{name}' but it was too similar to '{match}'"
+ )
+ await ctx.send(
+ f":x: The channel name `{name}` is too similar to `{match}`, and thus was not added. "
+ "Use `!otn forceadd` to override this check."
+ )
+ else:
+ await self._add_name(ctx, name)
+
+ @otname_group.command(name='forceadd', aliases=('fa',))
+ @with_role(*MODERATION_ROLES)
+ async def force_add_command(self, ctx: Context, *, name: OffTopicName) -> None:
+ """Forcefully adds a new off-topic name to the rotation."""
+ await self._add_name(ctx, name)
+
+ async def _add_name(self, ctx: Context, name: str) -> None:
+ """Adds an off-topic channel name to the site storage."""
+ await self.bot.api_client.post('bot/off-topic-channel-names', params={'name': name})
+
+ log.info(f"{ctx.author} added the off-topic channel name '{name}'")
await ctx.send(f":ok_hand: Added `{name}` to the names list.")
@otname_group.command(name='delete', aliases=('remove', 'rm', 'del', 'd'))
@with_role(*MODERATION_ROLES)
- async def delete_command(self, ctx: Context, *names: OffTopicName) -> None:
+ async def delete_command(self, ctx: Context, *, name: OffTopicName) -> None:
"""Removes a off-topic name from the rotation."""
- # Chain multiple words to a single one
- name = "-".join(names)
-
await self.bot.api_client.delete(f'bot/off-topic-channel-names/{name}')
- log.info(
- f"{ctx.author.name}#{ctx.author.discriminator}"
- f" deleted the off-topic channel name '{name}"
- )
+
+ log.info(f"{ctx.author} deleted the off-topic channel name '{name}'")
await ctx.send(f":ok_hand: Removed `{name}` from the names list.")
@otname_group.command(name='list', aliases=('l',))
@@ -150,7 +172,7 @@ class OffTopicNames(Cog):
close_matches = difflib.get_close_matches(query, result, n=10, cutoff=0.70)
lines = sorted(f"• {name}" for name in in_matches.union(close_matches))
embed = Embed(
- title=f"Query results",
+ title="Query results",
colour=Colour.blue()
)
diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py
index 6880aab85..0d06e9c26 100644
--- a/bot/cogs/reddit.py
+++ b/bot/cogs/reddit.py
@@ -5,10 +5,11 @@ import textwrap
from datetime import datetime, timedelta
from typing import List
-from discord import Colour, Embed, Message, TextChannel
+from discord import Colour, Embed, TextChannel
from discord.ext.commands import Bot, Cog, Context, group
+from discord.ext.tasks import loop
-from bot.constants import Channels, ERROR_REPLIES, Reddit as RedditConfig, STAFF_ROLES
+from bot.constants import Channels, ERROR_REPLIES, Emojis, Reddit as RedditConfig, STAFF_ROLES, Webhooks
from bot.converters import Subreddit
from bot.decorators import with_role
from bot.pagination import LinePaginator
@@ -26,13 +27,25 @@ class Reddit(Cog):
def __init__(self, bot: Bot):
self.bot = bot
- self.reddit_channel = None
+ self.webhook = None # set in on_ready
+ bot.loop.create_task(self.init_reddit_ready())
- self.prev_lengths = {}
- self.last_ids = {}
+ self.auto_poster_loop.start()
- self.new_posts_task = None
- self.top_weekly_posts_task = None
+ def cog_unload(self) -> None:
+ """Stops the loops when the cog is unloaded."""
+ self.auto_poster_loop.cancel()
+
+ async def init_reddit_ready(self) -> None:
+ """Sets the reddit webhook when the cog is loaded."""
+ await self.bot.wait_until_ready()
+ if not self.webhook:
+ self.webhook = await self.bot.fetch_webhook(Webhooks.reddit)
+
+ @property
+ def channel(self) -> TextChannel:
+ """Get the #reddit channel object from the bot's cache."""
+ return self.bot.get_channel(Channels.reddit)
async def fetch_posts(self, route: str, *, amount: int = 25, params: dict = None) -> List[dict]:
"""A helper method to fetch a certain amount of Reddit posts at a given route."""
@@ -61,23 +74,22 @@ class Reddit(Cog):
log.debug(f"Invalid response from: {url} - status code {response.status}, mimetype {response.content_type}")
return list() # Failed to get appropriate response within allowed number of retries.
- async def send_top_posts(
- self, channel: TextChannel, subreddit: Subreddit, content: str = None, time: str = "all"
- ) -> Message:
- """Create an embed for the top posts, then send it in a given TextChannel."""
- # Create the new spicy embed.
- embed = Embed()
- embed.description = ""
-
- # Get the posts
- async with channel.typing():
- posts = await self.fetch_posts(
- route=f"{subreddit}/top",
- amount=5,
- params={
- "t": time
- }
- )
+ async def get_top_posts(self, subreddit: Subreddit, time: str = "all", amount: int = 5) -> Embed:
+ """
+ Get the top amount of posts for a given subreddit within a specified timeframe.
+
+ A time of "all" will get posts from all time, "day" will get top daily posts and "week" will get the top
+ weekly posts.
+
+ The amount should be between 0 and 25 as Reddit's JSON requests only provide 25 posts at most.
+ """
+ embed = Embed(description="")
+
+ posts = await self.fetch_posts(
+ route=f"{subreddit}/top",
+ amount=amount,
+ params={"t": time}
+ )
if not posts:
embed.title = random.choice(ERROR_REPLIES)
@@ -87,9 +99,7 @@ class Reddit(Cog):
"If this problem persists, please let us know."
)
- return await channel.send(
- embed=embed
- )
+ return embed
for post in posts:
data = post["data"]
@@ -107,109 +117,58 @@ class Reddit(Cog):
link = self.URL + data["permalink"]
embed.description += (
- f"[**{title}**]({link})\n"
+ f"**[{title}]({link})**\n"
f"{text}"
- f"| {ups} upvotes | {comments} comments | u/{author} | {subreddit} |\n\n"
+ f"{Emojis.upvotes} {ups} {Emojis.comments} {comments} {Emojis.user} {author}\n\n"
)
embed.colour = Colour.blurple()
+ return embed
- return await channel.send(
- content=content,
- embed=embed
- )
-
- async def poll_new_posts(self) -> None:
- """Periodically search for new subreddit posts."""
- while True:
- await asyncio.sleep(RedditConfig.request_delay)
-
- for subreddit in RedditConfig.subreddits:
- # Make a HEAD request to the subreddit
- head_response = await self.bot.http_session.head(
- url=f"{self.URL}/{subreddit}/new.rss",
- headers=self.HEADERS
- )
-
- content_length = head_response.headers["content-length"]
-
- # If the content is the same size as before, assume there's no new posts.
- if content_length == self.prev_lengths.get(subreddit, None):
- continue
-
- self.prev_lengths[subreddit] = content_length
+ @loop()
+ async def auto_poster_loop(self) -> None:
+ """Post the top 5 posts daily, and the top 5 posts weekly."""
+ # once we upgrade to d.py 1.3 this can be removed and the loop can use the `time=datetime.time.min` parameter
+ now = datetime.utcnow()
+ tomorrow = now + timedelta(days=1)
+ midnight_tomorrow = tomorrow.replace(hour=0, minute=0, second=0)
+ seconds_until = (midnight_tomorrow - now).total_seconds()
- # Now we can actually fetch the new data
- posts = await self.fetch_posts(f"{subreddit}/new")
- new_posts = []
+ await asyncio.sleep(seconds_until)
- # Only show new posts if we've checked before.
- if subreddit in self.last_ids:
- for post in posts:
- data = post["data"]
+ await self.bot.wait_until_ready()
+ if not self.webhook:
+ await self.bot.fetch_webhook(Webhooks.reddit)
- # Convert the ID to an integer for easy comparison.
- int_id = int(data["id"], 36)
+ if datetime.utcnow().weekday() == 0:
+ await self.top_weekly_posts()
+ # if it's a monday send the top weekly posts
- # If we've already seen this post, finish checking
- if int_id <= self.last_ids[subreddit]:
- break
+ for subreddit in RedditConfig.subreddits:
+ top_posts = await self.get_top_posts(subreddit=subreddit, time="day")
+ await self.webhook.send(username=f"{subreddit} Top Daily Posts", embed=top_posts)
- embed_data = {
- "title": textwrap.shorten(data["title"], width=64, placeholder="..."),
- "text": textwrap.shorten(data["selftext"], width=128, placeholder="..."),
- "url": self.URL + data["permalink"],
- "author": data["author"]
- }
+ async def top_weekly_posts(self) -> None:
+ """Post a summary of the top posts."""
+ for subreddit in RedditConfig.subreddits:
+ # Send and pin the new weekly posts.
+ top_posts = await self.get_top_posts(subreddit=subreddit, time="week")
- new_posts.append(embed_data)
+ message = await self.webhook.send(wait=True, username=f"{subreddit} Top Weekly Posts", embed=top_posts)
- self.last_ids[subreddit] = int(posts[0]["data"]["id"], 36)
+ if subreddit.lower() == "r/python":
+ if not self.channel:
+ log.warning("Failed to get #reddit channel to remove pins in the weekly loop.")
+ return
- # Send all of the new posts as spicy embeds
- for data in new_posts:
- embed = Embed()
+ # Remove the oldest pins so that only 12 remain at most.
+ pins = await self.channel.pins()
- embed.title = data["title"]
- embed.url = data["url"]
- embed.description = data["text"]
- embed.set_footer(text=f"Posted by u/{data['author']} in {subreddit}")
- embed.colour = Colour.blurple()
+ while len(pins) >= 12:
+ await pins[-1].unpin()
+ del pins[-1]
- await self.reddit_channel.send(embed=embed)
-
- log.trace(f"Sent {len(new_posts)} new {subreddit} posts to channel {self.reddit_channel.id}.")
-
- async def poll_top_weekly_posts(self) -> None:
- """Post a summary of the top posts every week."""
- while True:
- now = datetime.utcnow()
-
- # Calculate the amount of seconds until midnight next monday.
- monday = now + timedelta(days=7 - now.weekday())
- monday = monday.replace(hour=0, minute=0, second=0)
- until_monday = (monday - now).total_seconds()
-
- await asyncio.sleep(until_monday)
-
- for subreddit in RedditConfig.subreddits:
- # Send and pin the new weekly posts.
- message = await self.send_top_posts(
- channel=self.reddit_channel,
- subreddit=subreddit,
- content=f"This week's top {subreddit} posts have arrived!",
- time="week"
- )
-
- if subreddit.lower() == "r/python":
- # Remove the oldest pins so that only 5 remain at most.
- pins = await self.reddit_channel.pins()
-
- while len(pins) >= 5:
- await pins[-1].unpin()
- del pins[-1]
-
- await message.pin()
+ await message.pin()
@group(name="reddit", invoke_without_command=True)
async def reddit_group(self, ctx: Context) -> None:
@@ -219,32 +178,26 @@ class Reddit(Cog):
@reddit_group.command(name="top")
async def top_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of all time from a given subreddit."""
- await self.send_top_posts(
- channel=ctx.channel,
- subreddit=subreddit,
- content=f"Here are the top {subreddit} posts of all time!",
- time="all"
- )
+ async with ctx.typing():
+ embed = await self.get_top_posts(subreddit=subreddit, time="all")
+
+ await ctx.send(content=f"Here are the top {subreddit} posts of all time!", embed=embed)
@reddit_group.command(name="daily")
async def daily_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of today from a given subreddit."""
- await self.send_top_posts(
- channel=ctx.channel,
- subreddit=subreddit,
- content=f"Here are today's top {subreddit} posts!",
- time="day"
- )
+ async with ctx.typing():
+ embed = await self.get_top_posts(subreddit=subreddit, time="day")
+
+ await ctx.send(content=f"Here are today's top {subreddit} posts!", embed=embed)
@reddit_group.command(name="weekly")
async def weekly_command(self, ctx: Context, subreddit: Subreddit = "r/Python") -> None:
"""Send the top posts of this week from a given subreddit."""
- await self.send_top_posts(
- channel=ctx.channel,
- subreddit=subreddit,
- content=f"Here are this week's top {subreddit} posts!",
- time="week"
- )
+ async with ctx.typing():
+ embed = await self.get_top_posts(subreddit=subreddit, time="week")
+
+ await ctx.send(content=f"Here are this week's top {subreddit} posts!", embed=embed)
@with_role(*STAFF_ROLES)
@reddit_group.command(name="subreddits", aliases=("subs",))
@@ -262,19 +215,6 @@ class Reddit(Cog):
max_lines=15
)
- @Cog.listener()
- async def on_ready(self) -> None:
- """Initiate reddit post event loop."""
- self.reddit_channel = await self.bot.fetch_channel(Channels.reddit)
-
- if self.reddit_channel is not None:
- if self.new_posts_task is None:
- self.new_posts_task = self.bot.loop.create_task(self.poll_new_posts())
- if self.top_weekly_posts_task is None:
- self.top_weekly_posts_task = self.bot.loop.create_task(self.poll_top_weekly_posts())
- else:
- log.warning("Couldn't locate a channel for subreddit relaying.")
-
def setup(bot: Bot) -> None:
"""Reddit cog load."""
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index 6e91d2c06..81990704b 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -2,7 +2,7 @@ import asyncio
import logging
import random
import textwrap
-from datetime import datetime
+from datetime import datetime, timedelta
from operator import itemgetter
from typing import Optional
@@ -30,9 +30,11 @@ class Reminders(Scheduler, Cog):
self.bot = bot
super().__init__()
- @Cog.listener()
- async def on_ready(self) -> None:
+ self.bot.loop.create_task(self.reschedule_reminders())
+
+ async def reschedule_reminders(self) -> None:
"""Get all current reminders from the API and reschedule them."""
+ await self.bot.wait_until_ready()
response = await self.bot.api_client.get(
'bot/reminders',
params={'active': 'true'}
@@ -102,7 +104,10 @@ class Reminders(Scheduler, Cog):
name="It has arrived!"
)
- embed.description = f"Here's your reminder: `{reminder['content']}`"
+ embed.description = f"Here's your reminder: `{reminder['content']}`."
+
+ if reminder.get("jump_url"): # keep backward compatibility
+ embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})"
if late:
embed.colour = Colour.red()
@@ -165,14 +170,18 @@ class Reminders(Scheduler, Cog):
json={
'author': ctx.author.id,
'channel_id': ctx.message.channel.id,
+ 'jump_url': ctx.message.jump_url,
'content': content,
'expiration': expiration.isoformat()
}
)
+ now = datetime.utcnow() - timedelta(seconds=1)
+
# Confirm to the user that it worked.
await self._send_confirmation(
- ctx, on_success="Your reminder has been created successfully!"
+ ctx,
+ on_success=f"Your reminder will arrive in {humanize_delta(relativedelta(expiration, now))}!"
)
loop = asyncio.get_event_loop()
diff --git a/bot/cogs/site.py b/bot/cogs/site.py
index c3bdf85e4..683613788 100644
--- a/bot/cogs/site.py
+++ b/bot/cogs/site.py
@@ -3,8 +3,7 @@ import logging
from discord import Colour, Embed
from discord.ext.commands import Bot, Cog, Context, group
-from bot.constants import Channels, STAFF_ROLES, URLs
-from bot.decorators import redirect_output
+from bot.constants import URLs
from bot.pagination import LinePaginator
log = logging.getLogger(__name__)
@@ -105,7 +104,6 @@ class Site(Cog):
await ctx.send(embed=embed)
@site_group.command(aliases=['r', 'rule'], name='rules')
- @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES)
async def site_rules(self, ctx: Context, *rules: int) -> None:
"""Provides a link to all rules or, if specified, displays specific rule(s)."""
rules_embed = Embed(title='Rules', color=Colour.blurple())
@@ -126,15 +124,15 @@ class Site(Cog):
invalid_indices = tuple(
pick
for pick in rules
- if pick < 0 or pick >= len(full_rules)
+ if pick < 1 or pick > len(full_rules)
)
if invalid_indices:
indices = ', '.join(map(str, invalid_indices))
- await ctx.send(f":x: Invalid rule indices {indices}")
+ await ctx.send(f":x: Invalid rule indices: {indices}")
return
- final_rules = tuple(f"**{pick}.** {full_rules[pick]}" for pick in rules)
+ final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules)
await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3)
diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py
index 81185cf3e..362968bd0 100644
--- a/bot/cogs/snekbox.py
+++ b/bot/cogs/snekbox.py
@@ -115,6 +115,16 @@ class Snekbox(Cog):
return msg, error
+ @staticmethod
+ def get_status_emoji(results: dict) -> str:
+ """Return an emoji corresponding to the status code or lack of output in result."""
+ if not results["stdout"].strip(): # No output
+ return ":warning:"
+ elif results["returncode"] == 0: # No error
+ return ":white_check_mark:"
+ else: # Exception
+ return ":x:"
+
async def format_output(self, output: str) -> Tuple[str, Optional[str]]:
"""
Format the output and return a tuple of the formatted output and a URL to the full output.
@@ -178,7 +188,7 @@ class Snekbox(Cog):
if ctx.author.id in self.jobs:
await ctx.send(
f"{ctx.author.mention} You've already got a job running - "
- f"please wait for it to finish!"
+ "please wait for it to finish!"
)
return
@@ -186,10 +196,7 @@ class Snekbox(Cog):
await ctx.invoke(self.bot.get_command("help"), "eval")
return
- log.info(
- f"Received code from {ctx.author.name}#{ctx.author.discriminator} "
- f"for evaluation:\n{code}"
- )
+ log.info(f"Received code from {ctx.author} for evaluation:\n{code}")
self.jobs[ctx.author.id] = datetime.datetime.now()
code = self.prepare_input(code)
@@ -204,7 +211,8 @@ class Snekbox(Cog):
else:
output, paste_link = await self.format_output(results["stdout"])
- msg = f"{ctx.author.mention} {msg}.\n\n```py\n{output}\n```"
+ icon = self.get_status_emoji(results)
+ msg = f"{ctx.author.mention} {icon} {msg}.\n\n```py\n{output}\n```"
if paste_link:
msg = f"{msg}\nFull output: {paste_link}"
@@ -213,10 +221,7 @@ class Snekbox(Cog):
wait_for_deletion(response, user_ids=(ctx.author.id,), client=ctx.bot)
)
- log.info(
- f"{ctx.author.name}#{ctx.author.discriminator}'s job had a return code of "
- f"{results['returncode']}"
- )
+ log.info(f"{ctx.author}'s job had a return code of {results['returncode']}")
finally:
del self.jobs[ctx.author.id]
diff --git a/bot/cogs/superstarify/stars.py b/bot/cogs/superstarify/stars.py
deleted file mode 100644
index dbac86770..000000000
--- a/bot/cogs/superstarify/stars.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import random
-
-
-STAR_NAMES = (
- "Adele",
- "Aerosmith",
- "Aretha Franklin",
- "Ayumi Hamasaki",
- "B'z",
- "Barbra Streisand",
- "Barry Manilow",
- "Barry White",
- "Beyonce",
- "Billy Joel",
- "Bob Dylan",
- "Bob Marley",
- "Bob Seger",
- "Bon Jovi",
- "Britney Spears",
- "Bruce Springsteen",
- "Bruno Mars",
- "Bryan Adams",
- "Celine Dion",
- "Cher",
- "Christina Aguilera",
- "David Bowie",
- "Donna Summer",
- "Drake",
- "Ed Sheeran",
- "Elton John",
- "Elvis Presley",
- "Eminem",
- "Enya",
- "Flo Rida",
- "Frank Sinatra",
- "Garth Brooks",
- "George Michael",
- "George Strait",
- "James Taylor",
- "Janet Jackson",
- "Jay-Z",
- "Johnny Cash",
- "Johnny Hallyday",
- "Julio Iglesias",
- "Justin Bieber",
- "Justin Timberlake",
- "Kanye West",
- "Katy Perry",
- "Kenny G",
- "Kenny Rogers",
- "Lady Gaga",
- "Lil Wayne",
- "Linda Ronstadt",
- "Lionel Richie",
- "Madonna",
- "Mariah Carey",
- "Meat Loaf",
- "Michael Jackson",
- "Neil Diamond",
- "Nicki Minaj",
- "Olivia Newton-John",
- "Paul McCartney",
- "Phil Collins",
- "Pink",
- "Prince",
- "Reba McEntire",
- "Rihanna",
- "Robbie Williams",
- "Rod Stewart",
- "Santana",
- "Shania Twain",
- "Stevie Wonder",
- "Taylor Swift",
- "Tim McGraw",
- "Tina Turner",
- "Tom Petty",
- "Tupac Shakur",
- "Usher",
- "Van Halen",
- "Whitney Houston",
-)
-
-
-def get_nick(infraction_id: int, member_id: int) -> str:
- """Randomly select a nickname from the Superstarify nickname list."""
- rng = random.Random(str(infraction_id) + str(member_id))
- return rng.choice(STAR_NAMES)
diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py
index b75fb26cd..aaa581f96 100644
--- a/bot/cogs/sync/cog.py
+++ b/bot/cogs/sync/cog.py
@@ -29,9 +29,11 @@ class Sync(Cog):
def __init__(self, bot: Bot) -> None:
self.bot = bot
- @Cog.listener()
- async def on_ready(self) -> None:
+ self.bot.loop.create_task(self.sync_guild())
+
+ async def sync_guild(self) -> None:
"""Syncs the roles/users of the guild with the database."""
+ await self.bot.wait_until_ready()
guild = self.bot.get_guild(self.SYNC_SERVER_ID)
if guild is not None:
for syncer in self.ON_READY_SYNCERS:
diff --git a/bot/cogs/token_remover.py b/bot/cogs/token_remover.py
index e5b0e5b45..11f4bc896 100644
--- a/bot/cogs/token_remover.py
+++ b/bot/cogs/token_remover.py
@@ -9,7 +9,7 @@ from discord import Colour, Message
from discord.ext.commands import Bot, Cog
from discord.utils import snowflake_time
-from bot.cogs.modlog import ModLog
+from bot.cogs.moderation import ModLog
from bot.constants import Channels, Colours, Event, Icons
log = logging.getLogger(__name__)
@@ -26,11 +26,11 @@ DELETION_MESSAGE_TEMPLATE = (
DISCORD_EPOCH_TIMESTAMP = datetime(2017, 1, 1)
TOKEN_EPOCH = 1_293_840_000
TOKEN_RE = re.compile(
- r"[^\s\.]+" # Matches token part 1: The user ID string, encoded as base64
- r"\." # Matches a literal dot between the token parts
- r"[^\s\.]+" # Matches token part 2: The creation timestamp, as an integer
- r"\." # Matches a literal dot between the token parts
- r"[^\s\.]+" # Matches token part 3: The HMAC, unused by us, but check that it isn't empty
+ r"[^\s\.()\"']+" # Matches token part 1: The user ID string, encoded as base64
+ r"\." # Matches a literal dot between the token parts
+ r"[^\s\.()\"']+" # Matches token part 2: The creation timestamp, as an integer
+ r"\." # Matches a literal dot between the token parts
+ r"[^\s\.()\"']+" # Matches token part 3: The HMAC, unused by us, but check that it isn't empty
)
diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py
index b6cecdc7c..793fe4c1a 100644
--- a/bot/cogs/utils.py
+++ b/bot/cogs/utils.py
@@ -1,15 +1,18 @@
import logging
import re
import unicodedata
+from asyncio import TimeoutError, sleep
from email.parser import HeaderParser
from io import StringIO
from typing import Tuple
-from discord import Colour, Embed
+from dateutil import relativedelta
+from discord import Colour, Embed, Message, Role
from discord.ext.commands import Bot, Cog, Context, command
-from bot.constants import Channels, STAFF_ROLES
-from bot.decorators import in_channel
+from bot.constants import Channels, MODERATION_ROLES, Mention, STAFF_ROLES
+from bot.decorators import in_channel, with_role
+from bot.utils.time import humanize_delta
log = logging.getLogger(__name__)
@@ -32,56 +35,58 @@ class Utils(Cog):
await ctx.invoke(self.bot.get_command("help"), "pep")
return
- # Newer PEPs are written in RST instead of txt
- if pep_number > 542:
- pep_url = f"{self.base_github_pep_url}{pep_number:04}.rst"
- else:
- pep_url = f"{self.base_github_pep_url}{pep_number:04}.txt"
-
- # Attempt to fetch the PEP
- log.trace(f"Requesting PEP {pep_number} with {pep_url}")
- response = await self.bot.http_session.get(pep_url)
-
- if response.status == 200:
- log.trace("PEP found")
+ possible_extensions = ['.txt', '.rst']
+ found_pep = False
+ for extension in possible_extensions:
+ # Attempt to fetch the PEP
+ pep_url = f"{self.base_github_pep_url}{pep_number:04}{extension}"
+ log.trace(f"Requesting PEP {pep_number} with {pep_url}")
+ response = await self.bot.http_session.get(pep_url)
- pep_content = await response.text()
+ if response.status == 200:
+ log.trace("PEP found")
+ found_pep = True
- # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179
- pep_header = HeaderParser().parse(StringIO(pep_content))
+ pep_content = await response.text()
- # Assemble the embed
- pep_embed = Embed(
- title=f"**PEP {pep_number} - {pep_header['Title']}**",
- description=f"[Link]({self.base_pep_url}{pep_number:04})",
- )
-
- pep_embed.set_thumbnail(url="https://www.python.org/static/opengraph-icon-200x200.png")
+ # Taken from https://github.com/python/peps/blob/master/pep0/pep.py#L179
+ pep_header = HeaderParser().parse(StringIO(pep_content))
- # Add the interesting information
- if "Status" in pep_header:
- pep_embed.add_field(name="Status", value=pep_header["Status"])
- if "Python-Version" in pep_header:
- pep_embed.add_field(name="Python-Version", value=pep_header["Python-Version"])
- if "Created" in pep_header:
- pep_embed.add_field(name="Created", value=pep_header["Created"])
- if "Type" in pep_header:
- pep_embed.add_field(name="Type", value=pep_header["Type"])
+ # Assemble the embed
+ pep_embed = Embed(
+ title=f"**PEP {pep_number} - {pep_header['Title']}**",
+ description=f"[Link]({self.base_pep_url}{pep_number:04})",
+ )
- elif response.status == 404:
+ pep_embed.set_thumbnail(url="https://www.python.org/static/opengraph-icon-200x200.png")
+
+ # Add the interesting information
+ if "Status" in pep_header:
+ pep_embed.add_field(name="Status", value=pep_header["Status"])
+ if "Python-Version" in pep_header:
+ pep_embed.add_field(name="Python-Version", value=pep_header["Python-Version"])
+ if "Created" in pep_header:
+ pep_embed.add_field(name="Created", value=pep_header["Created"])
+ if "Type" in pep_header:
+ pep_embed.add_field(name="Type", value=pep_header["Type"])
+
+ elif response.status != 404:
+ # any response except 200 and 404 is expected
+ found_pep = True # actually not, but it's easier to display this way
+ log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: "
+ f"{response.status}.\n{response.text}")
+
+ error_message = "Unexpected HTTP error during PEP search. Please let us know."
+ pep_embed = Embed(title="Unexpected error", description=error_message)
+ pep_embed.colour = Colour.red()
+ break
+
+ if not found_pep:
log.trace("PEP was not found")
not_found = f"PEP {pep_number} does not exist."
pep_embed = Embed(title="PEP not found", description=not_found)
pep_embed.colour = Colour.red()
- else:
- log.trace(f"The user requested PEP {pep_number}, but the response had an unexpected status code: "
- f"{response.status}.\n{response.text}")
-
- error_message = "Unexpected HTTP error during PEP search. Please let us know."
- pep_embed = Embed(title="Unexpected error", description=error_message)
- pep_embed.colour = Colour.red()
-
await ctx.message.channel.send(embed=pep_embed)
@command()
@@ -128,6 +133,47 @@ class Utils(Cog):
await ctx.send(embed=embed)
+ @command()
+ @with_role(*MODERATION_ROLES)
+ async def mention(self, ctx: Context, *, role: Role) -> None:
+ """Set a role to be mentionable for a limited time."""
+ if role.mentionable:
+ await ctx.send(f"{role} is already mentionable!")
+ return
+
+ await role.edit(reason=f"Role unlocked by {ctx.author}", mentionable=True)
+
+ human_time = humanize_delta(relativedelta.relativedelta(seconds=Mention.message_timeout))
+ await ctx.send(
+ f"{role} has been made mentionable. I will reset it in {human_time}, or when someone mentions this role."
+ )
+
+ def check(m: Message) -> bool:
+ """Checks that the message contains the role mention."""
+ return role in m.role_mentions
+
+ try:
+ msg = await self.bot.wait_for("message", check=check, timeout=Mention.message_timeout)
+ except TimeoutError:
+ await role.edit(mentionable=False, reason="Automatic role lock - timeout.")
+ await ctx.send(f"{ctx.author.mention}, you took too long. I have reset {role} to be unmentionable.")
+ return
+
+ if any(r.id in MODERATION_ROLES for r in msg.author.roles):
+ await sleep(Mention.reset_delay)
+ await role.edit(mentionable=False, reason=f"Automatic role lock by {msg.author}")
+ await ctx.send(
+ f"{ctx.author.mention}, I have reset {role} to be unmentionable as "
+ f"{msg.author if msg.author != ctx.author else 'you'} sent a message mentioning it."
+ )
+ return
+
+ await role.edit(mentionable=False, reason=f"Automatic role lock - unauthorised use by {msg.author}")
+ await ctx.send(
+ f"{ctx.author.mention}, I have reset {role} to be unmentionable "
+ f"as I detected unauthorised use by {msg.author} (ID: {msg.author.id})."
+ )
+
def setup(bot: Bot) -> None:
"""Utils cog load."""
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index f0a099f27..5b115deaa 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -1,10 +1,12 @@
import logging
+from datetime import datetime
from discord import Message, NotFound, Object
+from discord.ext import tasks
from discord.ext.commands import Bot, Cog, Context, command
-from bot.cogs.modlog import ModLog
-from bot.constants import Channels, Event, Roles
+from bot.cogs.moderation import ModLog
+from bot.constants import Bot as BotConfig, Channels, Event, Roles
from bot.decorators import InChannelCheckFailure, in_channel, without_role
log = logging.getLogger(__name__)
@@ -27,12 +29,18 @@ from time to time, you can send `!subscribe` to <#{Channels.bot}> at any time to
If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to <#{Channels.bot}>.
"""
+PERIODIC_PING = (
+ f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`."
+ f" Ping <@&{Roles.admin}> if you encounter any problems during the verification process."
+)
+
class Verification(Cog):
"""User verification and role self-management."""
def __init__(self, bot: Bot):
self.bot = bot
+ self.periodic_ping.start()
@property
def mod_log(self) -> ModLog:
@@ -155,6 +163,34 @@ class Verification(Cog):
else:
return True
+ @tasks.loop(hours=12)
+ async def periodic_ping(self) -> None:
+ """Every week, mention @everyone to remind them to verify."""
+ messages = self.bot.get_channel(Channels.verification).history(limit=10)
+ need_to_post = True # True if a new message needs to be sent.
+
+ async for message in messages:
+ if message.author == self.bot.user and message.content == PERIODIC_PING:
+ delta = datetime.utcnow() - message.created_at # Time since last message.
+ if delta.days >= 7: # Message is older than a week.
+ await message.delete()
+ else:
+ need_to_post = False
+
+ break
+
+ if need_to_post:
+ await self.bot.get_channel(Channels.verification).send(PERIODIC_PING)
+
+ @periodic_ping.before_loop
+ async def before_ping(self) -> None:
+ """Only start the loop when the bot is ready."""
+ await self.bot.wait_until_ready()
+
+ def cog_unload(self) -> None:
+ """Cancel the periodic ping task when the cog is unloaded."""
+ self.periodic_ping.cancel()
+
def setup(bot: Bot) -> None:
"""Verification cog load."""
diff --git a/bot/cogs/watchchannels/bigbrother.py b/bot/cogs/watchchannels/bigbrother.py
index e191c2dbc..c516508ca 100644
--- a/bot/cogs/watchchannels/bigbrother.py
+++ b/bot/cogs/watchchannels/bigbrother.py
@@ -5,9 +5,9 @@ from typing import Union
from discord import User
from discord.ext.commands import Bot, Cog, Context, group
+from bot.cogs.moderation.utils import post_infraction
from bot.constants import Channels, Roles, Webhooks
from bot.decorators import with_role
-from bot.utils.moderation import post_infraction
from .watchchannel import WatchChannel, proxy_user
log = logging.getLogger(__name__)
@@ -64,13 +64,31 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
await ctx.send(":x: The specified user is already being watched.")
return
- response = await post_infraction(
- ctx, user, type='watch', reason=reason, hidden=True
- )
+ response = await post_infraction(ctx, user, 'watch', reason, hidden=True)
if response is not None:
self.watched_users[user.id] = response
- await ctx.send(f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother.")
+ msg = f":white_check_mark: Messages sent by {user} will now be relayed to Big Brother."
+
+ history = await self.bot.api_client.get(
+ self.api_endpoint,
+ params={
+ "user__id": str(user.id),
+ "active": "false",
+ 'type': 'watch',
+ 'ordering': '-inserted_at'
+ }
+ )
+
+ if len(history) > 1:
+ total = f"({len(history) // 2} previous infractions in total)"
+ end_reason = history[0]["reason"]
+ start_reason = f"Watched: {history[1]['reason']}"
+ msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```"
+ else:
+ msg = ":x: Failed to post the infraction: response was empty."
+
+ await ctx.send(msg)
@bigbrother_group.command(name='unwatch', aliases=('uw',))
@with_role(Roles.owner, Roles.admin, Roles.moderator)
@@ -91,7 +109,7 @@ class BigBrother(WatchChannel, Cog, name="Big Brother"):
json={'active': False}
)
- await post_infraction(ctx, user, type='watch', reason=f"Unwatched: {reason}", hidden=True, active=False)
+ await post_infraction(ctx, user, 'watch', f"Unwatched: {reason}", hidden=True, active=False)
await ctx.send(f":white_check_mark: Messages sent by {user} will no longer be relayed.")
diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py
index 4a23902d5..176c6f760 100644
--- a/bot/cogs/watchchannels/talentpool.py
+++ b/bot/cogs/watchchannels/talentpool.py
@@ -93,7 +93,24 @@ class TalentPool(WatchChannel, Cog, name="Talentpool"):
resp.raise_for_status()
self.watched_users[user.id] = response_data
- await ctx.send(f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel")
+ msg = f":white_check_mark: Messages sent by {user} will now be relayed to the talent pool channel"
+
+ history = await self.bot.api_client.get(
+ self.api_endpoint,
+ params={
+ "user__id": str(user.id),
+ "active": "false",
+ "ordering": "-inserted_at"
+ }
+ )
+
+ if history:
+ total = f"({len(history)} previous nominations in total)"
+ start_reason = f"Watched: {history[0]['reason']}"
+ end_reason = f"Unwatched: {history[0]['end_reason']}"
+ msg += f"\n\nUser's previous watch reasons {total}:```{start_reason}\n\n{end_reason}```"
+
+ await ctx.send(msg)
@nomination_group.command(name='history', aliases=('info', 'search'))
@with_role(Roles.owner, Roles.admin, Roles.moderator)
diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py
index 760e012eb..0bf75a924 100644
--- a/bot/cogs/watchchannels/watchchannel.py
+++ b/bot/cogs/watchchannels/watchchannel.py
@@ -13,7 +13,7 @@ from discord import Color, Embed, HTTPException, Message, Object, errors
from discord.ext.commands import BadArgument, Bot, Cog, Context
from bot.api import ResponseCodeError
-from bot.cogs.modlog import ModLog
+from bot.cogs.moderation import ModLog
from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons
from bot.pagination import LinePaginator
from bot.utils import CogABCMeta, messages
diff --git a/bot/constants.py b/bot/constants.py
index 1deeaa3b8..d3e79b4c2 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -259,6 +259,10 @@ class Emojis(metaclass=YAMLGetter):
pencil: str
cross_mark: str
+ upvotes: str
+ comments: str
+ user: str
+
class Icons(metaclass=YAMLGetter):
section = "style"
@@ -328,6 +332,7 @@ class Channels(metaclass=YAMLGetter):
subsection = "channels"
admins: int
+ admin_spam: int
announcements: int
big_brother_logs: int
bot: int
@@ -345,11 +350,15 @@ class Channels(metaclass=YAMLGetter):
help_7: int
helpers: int
message_log: int
+ meta: int
+ mod_spam: int
+ mods: int
mod_alerts: int
modlog: int
off_topic_0: int
off_topic_1: int
off_topic_2: int
+ organisation: int
python: int
reddit: int
talent_pool: int
@@ -364,6 +373,7 @@ class Webhooks(metaclass=YAMLGetter):
talent_pool: int
big_brother: int
+ reddit: int
class Roles(metaclass=YAMLGetter):
@@ -391,6 +401,7 @@ class Guild(metaclass=YAMLGetter):
id: int
ignored: List[int]
+ staff_channels: List[int]
class Keys(metaclass=YAMLGetter):
@@ -438,7 +449,6 @@ class URLs(metaclass=YAMLGetter):
class Reddit(metaclass=YAMLGetter):
section = "reddit"
- request_delay: int
subreddits: list
@@ -460,6 +470,12 @@ class AntiSpam(metaclass=YAMLGetter):
rules: Dict[str, Dict[str, int]]
+class AntiMalware(metaclass=YAMLGetter):
+ section = "anti_malware"
+
+ whitelist: list
+
+
class BigBrother(metaclass=YAMLGetter):
section = 'big_brother'
@@ -475,6 +491,13 @@ class Free(metaclass=YAMLGetter):
cooldown_per: float
+class Mention(metaclass=YAMLGetter):
+ section = 'mention'
+
+ message_timeout: int
+ reset_delay: int
+
+
class RedirectOutput(metaclass=YAMLGetter):
section = 'redirect_output'
@@ -493,6 +516,12 @@ PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir))
MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner
STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner
+# Roles combinations
+STAFF_CHANNELS = Guild.staff_channels
+
+# Default Channel combinations
+MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam
+
# Bot replies
NEGATIVE_REPLIES = [
diff --git a/bot/converters.py b/bot/converters.py
index 6d6453486..cf0496541 100644
--- a/bot/converters.py
+++ b/bot/converters.py
@@ -4,6 +4,8 @@ from datetime import datetime
from ssl import CertificateError
from typing import Union
+import dateutil.parser
+import dateutil.tz
import discord
from aiohttp import ClientConnectorError
from dateutil.relativedelta import relativedelta
@@ -215,3 +217,45 @@ class Duration(Converter):
now = datetime.utcnow()
return now + delta
+
+
+class ISODateTime(Converter):
+ """Converts an ISO-8601 datetime string into a datetime.datetime."""
+
+ async def convert(self, ctx: Context, datetime_string: str) -> datetime:
+ """
+ Converts a ISO-8601 `datetime_string` into a `datetime.datetime` object.
+
+ The converter is flexible in the formats it accepts, as it uses the `isoparse` method of
+ `dateutil.parser`. In general, it accepts datetime strings that start with a date,
+ optionally followed by a time. Specifying a timezone offset in the datetime string is
+ supported, but the `datetime` object will be converted to UTC and will be returned without
+ `tzinfo` as a timezone-unaware `datetime` object.
+
+ See: https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.isoparse
+
+ Formats that are guaranteed to be valid by our tests are:
+
+ - `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
+ - `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM`
+ - `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM`
+ - `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH`
+ - `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS`
+ - `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM`
+ - `YYYY-mm-dd`
+ - `YYYY-mm`
+ - `YYYY`
+
+ Note: ISO-8601 specifies a `T` as the separator between the date and the time part of the
+ datetime string. The converter accepts both a `T` and a single space character.
+ """
+ try:
+ dt = dateutil.parser.isoparse(datetime_string)
+ except ValueError:
+ raise BadArgument(f"`{datetime_string}` is not a valid ISO-8601 datetime string")
+
+ if dt.tzinfo:
+ dt = dt.astimezone(dateutil.tz.UTC)
+ dt = dt.replace(tzinfo=None)
+
+ return dt
diff --git a/bot/decorators.py b/bot/decorators.py
index 33a6bcadd..935df4af0 100644
--- a/bot/decorators.py
+++ b/bot/decorators.py
@@ -3,13 +3,13 @@ import random
from asyncio import Lock, sleep
from contextlib import suppress
from functools import wraps
-from typing import Any, Callable, Container, Optional
+from typing import Callable, Container, Union
from weakref import WeakValueDictionary
-from discord import Colour, Embed
+from discord import Colour, Embed, Member
from discord.errors import NotFound
from discord.ext import commands
-from discord.ext.commands import CheckFailure, Context
+from discord.ext.commands import CheckFailure, Cog, Context
from bot.constants import ERROR_REPLIES, RedirectOutput
from bot.utils.checks import with_role_check, without_role_check
@@ -72,13 +72,13 @@ def locked() -> Callable:
Subsequent calls to the command from the same author are ignored until the command has completed invocation.
- This decorator has to go before (below) the `command` decorator.
+ This decorator must go before (below) the `command` decorator.
"""
def wrap(func: Callable) -> Callable:
func.__locks = WeakValueDictionary()
@wraps(func)
- async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Optional[Any]:
+ async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
lock = func.__locks.setdefault(ctx.author.id, Lock())
if lock.locked():
embed = Embed()
@@ -93,7 +93,7 @@ def locked() -> Callable:
return
async with func.__locks.setdefault(ctx.author.id, Lock()):
- return await func(self, ctx, *args, **kwargs)
+ await func(self, ctx, *args, **kwargs)
return inner
return wrap
@@ -103,17 +103,21 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
Changes the channel in the context of the command to redirect the output to a certain channel.
Redirect is bypassed if the author has a role to bypass redirection.
+
+ This decorator must go before (below) the `command` decorator.
"""
def wrap(func: Callable) -> Callable:
@wraps(func)
- async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Any:
+ async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
if ctx.channel.id == destination_channel:
log.trace(f"Command {ctx.command.name} was invoked in destination_channel, not redirecting")
- return await func(self, ctx, *args, **kwargs)
+ await func(self, ctx, *args, **kwargs)
+ return
if bypass_roles and any(role.id in bypass_roles for role in ctx.author.roles):
log.trace(f"{ctx.author} has role to bypass output redirection")
- return await func(self, ctx, *args, **kwargs)
+ await func(self, ctx, *args, **kwargs)
+ return
redirect_channel = ctx.guild.get_channel(destination_channel)
old_channel = ctx.channel
@@ -140,3 +144,50 @@ def redirect_output(destination_channel: int, bypass_roles: Container[int] = Non
log.trace("Redirect output: Deleted invocation message")
return inner
return wrap
+
+
+def respect_role_hierarchy(target_arg: Union[int, str] = 0) -> Callable:
+ """
+ Ensure the highest role of the invoking member is greater than that of the target member.
+
+ If the condition fails, a warning is sent to the invoking context. A target which is not an
+ instance of discord.Member will always pass.
+
+ A value of 0 (i.e. position 0) for `target_arg` corresponds to the argument which comes after
+ `ctx`. If the target argument is a kwarg, its name can instead be given.
+
+ This decorator must go before (below) the `command` decorator.
+ """
+ def wrap(func: Callable) -> Callable:
+ @wraps(func)
+ async def inner(self: Cog, ctx: Context, *args, **kwargs) -> None:
+ try:
+ target = kwargs[target_arg]
+ except KeyError:
+ try:
+ target = args[target_arg]
+ except IndexError:
+ raise ValueError(f"Could not find target argument at position {target_arg}")
+ except TypeError:
+ raise ValueError(f"Could not find target kwarg with key {target_arg!r}")
+
+ if not isinstance(target, Member):
+ log.trace("The target is not a discord.Member; skipping role hierarchy check.")
+ await func(self, ctx, *args, **kwargs)
+ return
+
+ cmd = ctx.command.name
+ actor = ctx.author
+ if target.top_role >= actor.top_role:
+ log.info(
+ f"{actor} ({actor.id}) attempted to {cmd} "
+ f"{target} ({target.id}), who has an equal or higher top role."
+ )
+ await ctx.send(
+ f":x: {actor.mention}, you may not {cmd} "
+ "someone with an equal or higher top role."
+ )
+ else:
+ await func(self, ctx, *args, **kwargs)
+ return inner
+ return wrap
diff --git a/bot/resources/stars.json b/bot/resources/stars.json
index 8071b9626..c0b253120 100644
--- a/bot/resources/stars.json
+++ b/bot/resources/stars.json
@@ -1,82 +1,78 @@
-{
- "Adele": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7c/Adele_2016.jpg/220px-Adele_2016.jpg",
- "Steven Tyler": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a8/Steven_Tyler_by_Gage_Skidmore_3.jpg/220px-Steven_Tyler_by_Gage_Skidmore_3.jpg",
- "Alex Van Halen": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b3/Alex_Van_Halen_-_Van_Halen_Live.jpg/220px-Alex_Van_Halen_-_Van_Halen_Live.jpg",
- "Aretha Franklin": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c6/Aretha_Franklin_1968.jpg/220px-Aretha_Franklin_1968.jpg",
- "Ayumi Hamasaki": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/50/Ayumi_Hamasaki_2007.jpg/220px-Ayumi_Hamasaki_2007.jpg",
- "Koshi Inaba": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/af/B%27Z_at_Best_Buy_Theater_NYC_-_9-30-12_-_18.jpg/220px-B%27Z_at_Best_Buy_Theater_NYC_-_9-30-12_-_18.jpg",
- "Barbra Streisand": "https://upload.wikimedia.org/wikipedia/en/thumb/a/a3/Barbra_Streisand_-_1966.jpg/220px-Barbra_Streisand_-_1966.jpg",
- "Barry Manilow": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/2b/BarryManilow.jpg/220px-BarryManilow.jpg",
- "Barry White": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b7/Barry_White%2C_Bestanddeelnr_927-0099.jpg/220px-Barry_White%2C_Bestanddeelnr_927-0099.jpg",
- "Beyonce": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f2/Beyonce_-_The_Formation_World_Tour%2C_at_Wembley_Stadium_in_London%2C_England.jpg/220px-Beyonce_-_The_Formation_World_Tour%2C_at_Wembley_Stadium_in_London%2C_England.jpg",
- "Billy Joel": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/19/Billy_Joel_Shankbone_NYC_2009.jpg/220px-Billy_Joel_Shankbone_NYC_2009.jpg",
- "Bob Dylan": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/02/Bob_Dylan_-_Azkena_Rock_Festival_2010_2.jpg/220px-Bob_Dylan_-_Azkena_Rock_Festival_2010_2.jpg",
- "Bob Marley": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/Bob-Marley.jpg/220px-Bob-Marley.jpg",
- "Bob Seger": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/16/Bob_Seger_2013.jpg/220px-Bob_Seger_2013.jpg",
- "Jon Bon Jovi": "https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Jon_Bon_Jovi_at_the_2009_Tribeca_Film_Festival_3.jpg/220px-Jon_Bon_Jovi_at_the_2009_Tribeca_Film_Festival_3.jpg",
- "Britney Spears": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/da/Britney_Spears_2013_%28Straighten_Crop%29.jpg/200px-Britney_Spears_2013_%28Straighten_Crop%29.jpg",
- "Bruce Springsteen": "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/Bruce_Springsteen_-_Roskilde_Festival_2012.jpg/210px-Bruce_Springsteen_-_Roskilde_Festival_2012.jpg",
- "Bruno Mars": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b0/BrunoMars24KMagicWorldTourLive_%28cropped%29.jpg/220px-BrunoMars24KMagicWorldTourLive_%28cropped%29.jpg",
- "Bryan Adams": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7e/Bryan_Adams_Hamburg_MG_0631_flickr.jpg/300px-Bryan_Adams_Hamburg_MG_0631_flickr.jpg",
- "Celine Dion": "https://upload.wikimedia.org/wikipedia/commons/thumb/4/42/Celine_Dion_Concert_Singing_Taking_Chances_2008.jpg/220px-Celine_Dion_Concert_Singing_Taking_Chances_2008.jpg",
- "Cher": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Cher_-_Casablanca.jpg/220px-Cher_-_Casablanca.jpg",
- "Christina Aguilera": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e7/Christina_Aguilera_in_2016.jpg/220px-Christina_Aguilera_in_2016.jpg",
- "David Bowie": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e8/David-Bowie_Chicago_2002-08-08_photoby_Adam-Bielawski-cropped.jpg/220px-David-Bowie_Chicago_2002-08-08_photoby_Adam-Bielawski-cropped.jpg",
- "David Lee Roth": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/fb/David_Lee_Roth_-_Van_Halen.jpg/220px-David_Lee_Roth_-_Van_Halen.jpg",
- "Donna Summer": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Nobel_Peace_Price_Concert_2009_Donna_Summer3.jpg/220px-Nobel_Peace_Price_Concert_2009_Donna_Summer3.jpg",
- "Drake": "https://upload.wikimedia.org/wikipedia/commons/thumb/8/81/Drake_at_the_Velvet_Underground_-_2017_%2835986086223%29_%28cropped%29.jpg/220px-Drake_at_the_Velvet_Underground_-_2017_%2835986086223%29_%28cropped%29.jpg",
- "Ed Sheeran": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/55/Ed_Sheeran_2013.jpg/220px-Ed_Sheeran_2013.jpg",
- "Eddie Van Halen": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Eddie_Van_Halen.jpg/300px-Eddie_Van_Halen.jpg",
- "Elton John": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Elton_John_2011_Shankbone_2.JPG/220px-Elton_John_2011_Shankbone_2.JPG",
- "Elvis Presley": "https://upload.wikimedia.org/wikipedia/commons/thumb/9/99/Elvis_Presley_promoting_Jailhouse_Rock.jpg/220px-Elvis_Presley_promoting_Jailhouse_Rock.jpg",
- "Eminem": "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4a/Eminem_-_Concert_for_Valor_in_Washington%2C_D.C._Nov._11%2C_2014_%282%29_%28Cropped%29.jpg/220px-Eminem_-_Concert_for_Valor_in_Washington%2C_D.C._Nov._11%2C_2014_%282%29_%28Cropped%29.jpg",
- "Enya": "https://enya.com/wp-content/themes/enya%20full%20site/images/enya-about.jpg",
- "Flo Rida": "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/Flo_Rida_%286924266548%29.jpg/220px-Flo_Rida_%286924266548%29.jpg",
- "Frank Sinatra": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/af/Frank_Sinatra_%2757.jpg/220px-Frank_Sinatra_%2757.jpg",
- "Garth Brooks": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/bc/Garth_Brooks_on_World_Tour_%28crop%29.png/220px-Garth_Brooks_on_World_Tour_%28crop%29.png",
- "George Michael": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f2/George_Michael.jpeg/220px-George_Michael.jpeg",
- "George Strait": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0c/George_Strait_2014_1.jpg/220px-George_Strait_2014_1.jpg",
- "James Taylor": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/cf/James_Taylor_-_Columbia.jpg/220px-James_Taylor_-_Columbia.jpg",
- "Janet Jackson": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/02/JanetJacksonUnbreakableTourSanFran2015.jpg/220px-JanetJacksonUnbreakableTourSanFran2015.jpg",
- "Jay-Z": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Jay-Z.png/220px-Jay-Z.png",
- "Johnny Cash": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f2/JohnnyCash1969.jpg/220px-JohnnyCash1969.jpg",
- "Johnny Hallyday": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a1/Johnny_Hallyday_Cannes.jpg/220px-Johnny_Hallyday_Cannes.jpg",
- "Julio Iglesias": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ef/Julio_Iglesias09.jpg/220px-Julio_Iglesias09.jpg",
- "Justin Bieber": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/da/Justin_Bieber_in_2015.jpg/220px-Justin_Bieber_in_2015.jpg",
- "Justin Timberlake": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ed/Justin_Timberlake_by_Gage_Skidmore_2.jpg/220px-Justin_Timberlake_by_Gage_Skidmore_2.jpg",
- "Kanye West": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/11/Kanye_West_at_the_2009_Tribeca_Film_Festival.jpg/220px-Kanye_West_at_the_2009_Tribeca_Film_Festival.jpg",
- "Katy Perry": "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8a/Katy_Perry_at_Madison_Square_Garden_%2837436531092%29_%28cropped%29.jpg/220px-Katy_Perry_at_Madison_Square_Garden_%2837436531092%29_%28cropped%29.jpg",
- "Kenny G": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f4/KennyGHWOFMay2013.jpg/220px-KennyGHWOFMay2013.jpg",
- "Kenny Rogers": "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8c/KennyRogers.jpg/220px-KennyRogers.jpg",
- "Lady Gaga": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/2c/Lady_Gaga_interview_2016.jpg/220px-Lady_Gaga_interview_2016.jpg",
- "Lil Wayne": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/Lil_Wayne_%2823513397583%29.jpg/220px-Lil_Wayne_%2823513397583%29.jpg",
- "Linda Ronstadt": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/50/LindaRonstadtPerforming.jpg/220px-LindaRonstadtPerforming.jpg",
- "Lionel Richie": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/cd/Lionel_Richie_2017.jpg/220px-Lionel_Richie_2017.jpg",
- "Madonna": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Madonna_Rebel_Heart_Tour_2015_-_Stockholm_%2823051472299%29_%28cropped_2%29.jpg/220px-Madonna_Rebel_Heart_Tour_2015_-_Stockholm_%2823051472299%29_%28cropped_2%29.jpg",
- "Mariah Carey": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f2/Mariah_Carey_WBLS_2018_Interview_4.jpg/220px-Mariah_Carey_WBLS_2018_Interview_4.jpg",
- "Meat Loaf": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e7/Meat_Loaf.jpg/220px-Meat_Loaf.jpg",
- "Michael Jackson": "https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Michael_Jackson_in_1988.jpg/220px-Michael_Jackson_in_1988.jpg",
- "Neil Diamond": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f4/Neil_Diamond_HWOF_Aug_2012_other_%28levels_adjusted_and_cropped%29.jpg/220px-Neil_Diamond_HWOF_Aug_2012_other_%28levels_adjusted_and_cropped%29.jpg",
- "Nicki Minaj": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/54/Nicki_Minaj_MTV_VMAs_4.jpg/250px-Nicki_Minaj_MTV_VMAs_4.jpg",
- "Olivia Newton-John": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c7/Olivia_Newton-John_2.jpg/220px-Olivia_Newton-John_2.jpg",
- "Paul McCartney": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5d/Paul_McCartney_-_Out_There_Concert_-_140420-5941-jikatu_%2813950091384%29.jpg/220px-Paul_McCartney_-_Out_There_Concert_-_140420-5941-jikatu_%2813950091384%29.jpg",
- "Phil Collins": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f3/1_collins.jpg/220px-1_collins.jpg",
- "Pink": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1a/P%21nk_Live_2013.jpg/220px-P%21nk_Live_2013.jpg",
- "Prince": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Prince_1983_1st_Avenue.jpg/220px-Prince_1983_1st_Avenue.jpg",
- "Reba McEntire": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e0/Reba_McEntire_by_Gage_Skidmore.jpg/220px-Reba_McEntire_by_Gage_Skidmore.jpg",
- "Rihanna": "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/Rihanna_concert_in_Washington_DC_%282%29.jpg/250px-Rihanna_concert_in_Washington_DC_%282%29.jpg",
- "Robbie Williams": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/21/Robbie_Williams.jpg/220px-Robbie_Williams.jpg",
- "Rod Stewart": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/57/Rod_stewart_05111976_12_400.jpg/220px-Rod_stewart_05111976_12_400.jpg",
- "Carlos Santana": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/54/Santana_2010.jpg/220px-Santana_2010.jpg",
- "Shania Twain": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ee/ShaniaTwainJunoAwardsMar2011.jpg/220px-ShaniaTwainJunoAwardsMar2011.jpg",
- "Stevie Wonder": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/54/Stevie_Wonder_1973.JPG/220px-Stevie_Wonder_1973.JPG",
- "Tak Matsumoto": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/da/B%27Z_at_Best_Buy_Theater_NYC_-_9-30-12_-_22.jpg/220px-B%27Z_at_Best_Buy_Theater_NYC_-_9-30-12_-_22.jpg",
- "Taylor Swift": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/25/Taylor_Swift_112_%2818119055110%29_%28cropped%29.jpg/220px-Taylor_Swift_112_%2818119055110%29_%28cropped%29.jpg",
- "Tim McGraw": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5f/Tim_McGraw_October_24_2015.jpg/220px-Tim_McGraw_October_24_2015.jpg",
- "Tina Turner": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/Tina_turner_21021985_01_350.jpg/250px-Tina_turner_21021985_01_350.jpg",
- "Tom Petty": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5a/Tom_Petty_Live_in_Horsens_%28cropped2%29.jpg/220px-Tom_Petty_Live_in_Horsens_%28cropped2%29.jpg",
- "Tupac Shakur": "https://upload.wikimedia.org/wikipedia/en/thumb/b/b5/Tupac_Amaru_Shakur2.jpg/220px-Tupac_Amaru_Shakur2.jpg",
- "Usher": "https://upload.wikimedia.org/wikipedia/commons/thumb/f/fa/Usher_Cannes_2016_retusche.jpg/220px-Usher_Cannes_2016_retusche.jpg",
- "Whitney Houston": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Whitney_Houston_Welcome_Home_Heroes_1_cropped.jpg/220px-Whitney_Houston_Welcome_Home_Heroes_1_cropped.jpg",
- "Wolfgang Van Halen": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0c/Wolfgang_Van_Halen_Different_Kind_of_Truth_2012.jpg/220px-Wolfgang_Van_Halen_Different_Kind_of_Truth_2012.jpg"
-}
+[
+ "Adele",
+ "Aerosmith",
+ "Aretha Franklin",
+ "Ayumi Hamasaki",
+ "B'z",
+ "Barbra Streisand",
+ "Barry Manilow",
+ "Barry White",
+ "Beyonce",
+ "Billy Joel",
+ "Bob Dylan",
+ "Bob Marley",
+ "Bob Seger",
+ "Bon Jovi",
+ "Britney Spears",
+ "Bruce Springsteen",
+ "Bruno Mars",
+ "Bryan Adams",
+ "Celine Dion",
+ "Cher",
+ "Christina Aguilera",
+ "David Bowie",
+ "Donna Summer",
+ "Drake",
+ "Ed Sheeran",
+ "Elton John",
+ "Elvis Presley",
+ "Eminem",
+ "Enya",
+ "Flo Rida",
+ "Frank Sinatra",
+ "Garth Brooks",
+ "George Michael",
+ "George Strait",
+ "James Taylor",
+ "Janet Jackson",
+ "Jay-Z",
+ "Johnny Cash",
+ "Johnny Hallyday",
+ "Julio Iglesias",
+ "Justin Bieber",
+ "Justin Timberlake",
+ "Kanye West",
+ "Katy Perry",
+ "Kenny G",
+ "Kenny Rogers",
+ "Lady Gaga",
+ "Lil Wayne",
+ "Linda Ronstadt",
+ "Lionel Richie",
+ "Madonna",
+ "Mariah Carey",
+ "Meat Loaf",
+ "Michael Jackson",
+ "Neil Diamond",
+ "Nicki Minaj",
+ "Olivia Newton-John",
+ "Paul McCartney",
+ "Phil Collins",
+ "Pink",
+ "Prince",
+ "Reba McEntire",
+ "Rihanna",
+ "Robbie Williams",
+ "Rod Stewart",
+ "Santana",
+ "Shania Twain",
+ "Stevie Wonder",
+ "Taylor Swift",
+ "Tim McGraw",
+ "Tina Turner",
+ "Tom Petty",
+ "Tupac Shakur",
+ "Usher",
+ "Van Halen",
+ "Whitney Houston"
+]
diff --git a/bot/utils/checks.py b/bot/utils/checks.py
index 19f64ff9f..db56c347c 100644
--- a/bot/utils/checks.py
+++ b/bot/utils/checks.py
@@ -1,6 +1,8 @@
+import datetime
import logging
+from typing import Callable, Iterable
-from discord.ext.commands import Context
+from discord.ext.commands import BucketType, Cog, Command, CommandOnCooldown, Context, Cooldown, CooldownMapping
log = logging.getLogger(__name__)
@@ -36,9 +38,53 @@ def without_role_check(ctx: Context, *role_ids: int) -> bool:
return check
-def in_channel_check(ctx: Context, channel_id: int) -> bool:
- """Checks if the command was executed inside of the specified channel."""
- check = ctx.channel.id == channel_id
+def in_channel_check(ctx: Context, *channel_ids: int) -> bool:
+ """Checks if the command was executed inside the list of specified channels."""
+ check = ctx.channel.id in channel_ids
log.trace(f"{ctx.author} tried to call the '{ctx.command.name}' command. "
f"The result of the in_channel check was {check}.")
return check
+
+
+def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
+ bypass_roles: Iterable[int]) -> Callable:
+ """
+ Applies a cooldown to a command, but allows members with certain roles to be ignored.
+
+ NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future.
+ """
+ # make it a set so lookup is hash based
+ bypass = set(bypass_roles)
+
+ # this handles the actual cooldown logic
+ buckets = CooldownMapping(Cooldown(rate, per, type))
+
+ # will be called after the command has been parse but before it has been invoked, ensures that
+ # the cooldown won't be updated if the user screws up their input to the command
+ async def predicate(cog: Cog, ctx: Context) -> None:
+ nonlocal bypass, buckets
+
+ if any(role.id in bypass for role in ctx.author.roles):
+ return
+
+ # cooldown logic, taken from discord.py internals
+ current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp()
+ bucket = buckets.get_bucket(ctx.message)
+ retry_after = bucket.update_rate_limit(current)
+ if retry_after:
+ raise CommandOnCooldown(bucket, retry_after)
+
+ def wrapper(command: Command) -> Command:
+ # NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it
+ # so I just made it raise an error when the decorator is applied before the actual command object exists.
+ #
+ # if the `before_invoke` detail is ever a problem then I can quickly just swap over.
+ if not isinstance(command, Command):
+ raise TypeError('Decorator `cooldown_with_role_bypass` must be applied after the command decorator. '
+ 'This means it has to be above the command decorator in the code.')
+
+ command._before_invoke = predicate
+
+ return command
+
+ return wrapper
diff --git a/bot/utils/moderation.py b/bot/utils/moderation.py
deleted file mode 100644
index 7860f14a1..000000000
--- a/bot/utils/moderation.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import logging
-from datetime import datetime
-from typing import Optional, Union
-
-from discord import Member, Object, User
-from discord.ext.commands import Context
-
-from bot.api import ResponseCodeError
-from bot.constants import Keys
-
-log = logging.getLogger(__name__)
-
-HEADERS = {"X-API-KEY": Keys.site_api}
-
-
-async def post_infraction(
- ctx: Context,
- user: Union[Member, Object, User],
- type: str,
- reason: str,
- expires_at: datetime = None,
- hidden: bool = False,
- active: bool = True,
-) -> Optional[dict]:
- """Posts an infraction to the API."""
- payload = {
- "actor": ctx.message.author.id,
- "hidden": hidden,
- "reason": reason,
- "type": type,
- "user": user.id,
- "active": active
- }
- if expires_at:
- payload['expires_at'] = expires_at.isoformat()
-
- try:
- response = await ctx.bot.api_client.post('bot/infractions', json=payload)
- except ResponseCodeError as exp:
- if exp.status == 400 and 'user' in exp.response_json:
- log.info(
- f"{ctx.author} tried to add a {type} infraction to `{user.id}`, "
- "but that user id was not found in the database."
- )
- await ctx.send(f":x: Cannot add infraction, the specified user is not known to the database.")
- return
- else:
- log.exception("An unexpected ResponseCodeError occurred while adding an infraction:")
- await ctx.send(":x: There was an error adding the infraction.")
- return
-
- return response
-
-
-async def already_has_active_infraction(ctx: Context, user: Union[Member, Object, User], type: str) -> bool:
- """Checks if a user already has an active infraction of the given type."""
- active_infractions = await ctx.bot.api_client.get(
- 'bot/infractions',
- params={
- 'active': 'true',
- 'type': type,
- 'user__id': str(user.id)
- }
- )
- if active_infractions:
- await ctx.send(
- f":x: According to my records, this user already has a {type} infraction. "
- f"See infraction **#{active_infractions[0]['id']}**."
- )
- return True
- else:
- return False
diff --git a/bot/utils/time.py b/bot/utils/time.py
index da28f2c76..2aea2c099 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -1,5 +1,6 @@
import asyncio
import datetime
+from typing import Optional
import dateutil.parser
from dateutil.relativedelta import relativedelta
@@ -34,6 +35,9 @@ def humanize_delta(delta: relativedelta, precision: str = "seconds", max_units:
precision specifies the smallest unit of time to include (e.g. "seconds", "minutes").
max_units specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
"""
+ if max_units <= 0:
+ raise ValueError("max_units must be positive")
+
units = (
("years", delta.years),
("months", delta.months),
@@ -83,15 +87,20 @@ def time_since(past_datetime: datetime.datetime, precision: str = "seconds", max
return f"{humanized} ago"
-def parse_rfc1123(time_str: str) -> datetime.datetime:
+def parse_rfc1123(stamp: str) -> datetime.datetime:
"""Parse RFC1123 time string into datetime."""
- return datetime.datetime.strptime(time_str, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc)
+ return datetime.datetime.strptime(stamp, RFC1123_FORMAT).replace(tzinfo=datetime.timezone.utc)
# Hey, this could actually be used in the off_topic_names and reddit cogs :)
-async def wait_until(time: datetime.datetime) -> None:
- """Wait until a given time."""
- delay = time - datetime.datetime.utcnow()
+async def wait_until(time: datetime.datetime, start: Optional[datetime.datetime] = None) -> None:
+ """
+ Wait until a given time.
+
+ :param time: A datetime.datetime object to wait until.
+ :param start: The start from which to calculate the waiting duration. Defaults to UTC time.
+ """
+ delay = time - (start or datetime.datetime.utcnow())
delay_seconds = delay.total_seconds()
# Incorporate a small delay so we don't rapid-fire the event due to time precision errors
diff --git a/config-default.yml b/config-default.yml
index 38b26f64f..bce6ea266 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -37,6 +37,10 @@ style:
new: "\U0001F195"
cross_mark: "\u274C"
+ upvotes: "<:upvotes:638729835245731840>"
+ comments: "<:comments:638729835073765387>"
+ user: "<:user:638729835442602003>"
+
icons:
crown_blurple: "https://cdn.discordapp.com/emojis/469964153289965568.png"
crown_green: "https://cdn.discordapp.com/emojis/469964154719961088.png"
@@ -90,11 +94,12 @@ guild:
channels:
admins: &ADMINS 365960823622991872
+ admin_spam: &ADMIN_SPAM 563594791770914816
announcements: 354619224620138496
big_brother_logs: &BBLOGS 468507907357409333
bot: 267659945086812160
checkpoint_test: 422077681434099723
- defcon: 464469101889454091
+ defcon: &DEFCON 464469101889454091
devlog: &DEVLOG 622895325144940554
devtest: &DEVTEST 414574275865870337
help_0: 303906576991780866
@@ -105,13 +110,17 @@ guild:
help_5: 454941769734422538
help_6: 587375753306570782
help_7: 587375768556797982
- helpers: 385474242440986624
+ helpers: &HELPERS 385474242440986624
message_log: &MESSAGE_LOG 467752170159079424
+ meta: 429409067623251969
+ mod_spam: &MOD_SPAM 620607373828030464
+ mods: &MODS 305126844661760000
mod_alerts: 473092532147060736
modlog: &MODLOG 282638479504965634
off_topic_0: 291284109232308226
off_topic_1: 463035241142026251
off_topic_2: 463035268514185226
+ organisation: &ORGANISATION 551789653284356126
python: 267624335836053506
reddit: 458224812528238616
staff_lounge: &STAFF_LOUNGE 464905259261755392
@@ -120,6 +129,7 @@ guild:
user_event_a: &USER_EVENT_A 592000283102674944
verification: 352442727016693763
+ staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON]
ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG]
roles:
@@ -141,6 +151,7 @@ guild:
webhooks:
talent_pool: 569145364800602132
big_brother: 569133704568373283
+ reddit: 635408384794951680
filter:
@@ -282,7 +293,7 @@ anti_spam:
rules:
attachments:
interval: 10
- max: 3
+ max: 9
burst:
interval: 10
@@ -322,8 +333,28 @@ anti_spam:
max: 3
+anti_malware:
+ whitelist:
+ - '.3gp'
+ - '.3g2'
+ - '.avi'
+ - '.bmp'
+ - '.gif'
+ - '.h264'
+ - '.jpg'
+ - '.jpeg'
+ - '.m4v'
+ - '.mkv'
+ - '.mov'
+ - '.mp4'
+ - '.mpeg'
+ - '.mpg'
+ - '.png'
+ - '.tiff'
+ - '.wmv'
+
+
reddit:
- request_delay: 60
subreddits:
- 'r/Python'
@@ -347,6 +378,10 @@ free:
cooldown_rate: 1
cooldown_per: 60.0
+mention:
+ message_timeout: 300
+ reset_delay: 5
+
redirect_output:
delete_invocation: true
delete_delay: 15
diff --git a/docker-compose.yml b/docker-compose.yml
index 9684a3c62..f79fdba58 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -6,7 +6,7 @@ version: "3.7"
services:
postgres:
- image: postgres:11-alpine
+ image: postgres:12-alpine
environment:
POSTGRES_DB: pysite
POSTGRES_PASSWORD: pysite
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 000000000..6ab9bc93e
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,213 @@
+# Testing our Bot
+
+Our bot is one of the most important tools we have for running our community. As we don't want that tool break, we decided that we wanted to write unit tests for it. We hope that in the future, we'll have a 100% test coverage for the bot. This guide will help you get started with writing the tests needed to achieve that.
+
+_**Note:** This is a practical guide to getting started with writing tests for our bot, not a general introduction to writing unit tests in Python. If you're looking for a more general introduction, you may like Corey Schafer's [Python Tutorial: Unit Testing Your Code with the unittest Module](https://www.youtube.com/watch?v=6tNS--WetLI) or Ned Batchelder's PyCon talk [Getting Started Testing](https://www.youtube.com/watch?v=FxSsnHeWQBY)._
+
+## Tools
+
+We are using the following modules and packages for our unit tests:
+
+- [unittest](https://docs.python.org/3/library/unittest.html) (standard library)
+- [unittest.mock](https://docs.python.org/3/library/unittest.mock.html) (standard library)
+- [coverage.py](https://coverage.readthedocs.io/en/stable/)
+
+To ensure the results you obtain on your personal machine are comparable to those generated in the Azure pipeline, please make sure to run your tests with the virtual environment defined by our [Pipfile](/Pipfile). To run your tests with `pipenv`, we've provided two "scripts" shortcuts:
+
+- `pipenv run test` will run `unittest` with `coverage.py`
+- `pipenv run report` will generate a coverage report of the tests you've run with `pipenv run test`. If you append the `-m` flag to this command, the report will include the lines and branches not covered by tests in addition to the test coverage report.
+
+If you want a coverage report, make sure to run the tests with `pipenv run test` *first*.
+
+## Writing tests
+
+Since consistency is an important consideration for collaborative projects, we have written some guidelines on writing tests for the bot. In addition to these guidelines, it's a good idea to look at the existing code base for examples (e.g., [`test_converters.py`](/tests/bot/test_converters.py)).
+
+### File and directory structure
+
+To organize our test suite, we have chosen to mirror the directory structure of [`bot`](/bot/) in the [`tests`](/tests/) subdirectory. This makes it easy to find the relevant tests by providing a natural grouping of files. More general testing files, such as [`helpers.py`](/tests/helpers.py) are located directly in the `tests` subdirectory.
+
+All files containing tests should have a filename starting with `test_` to make sure `unittest` will discover them. This prefix is typically followed by the name of the file the tests are written for. If needed, a test file can contain multiple test classes, both to provide structure and to be able to provide different fixtures/set-up methods for different groups of tests.
+
+### Writing independent tests
+
+When writing unit tests, it's really important to make sure that each test that you write runs independently from all of the other tests. This both means that the code you write for one test shouldn't influence the result of another test and that if one tests fails, the other tests should still run.
+
+The basis for this is that when you write a test method, it should really only test a single aspect of the thing you're testing. This often means that you do not write one large test that tests "everything" that can be tested for a function, but rather that you write multiple smaller tests that each test a specific branch/path/condition of the function under scrutiny.
+
+To make sure you're not repeating the same set-up steps in all these smaller tests, `unittest` provides fixtures that are executed before and after each test is run. In addition to test fixtures, it also provides special set-up and clean-up methods that are run before the first test in a test class or after the last test of that class has been run. For more information, see the documentation for [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html#unittest.TestCase).
+
+#### Method names and docstrings
+
+As you can probably imagine, writing smaller, independent tests also results in a large number of tests. To make sure that it's easy to see which test does what, it is incredibly important to use good method names to identify what each test is doing. A general guideline is that the name should capture the goal of your test: What is this test method trying to assert?
+
+In addition to good method names, it's also really important to write a good *single-line* docstring. The `unittest` module will print such a single-line docstring along with the method name in the output it gives when a test fails. This means that a good docstring that really captures the purpose of the test makes it much easier to quickly make sense of output.
+
+#### Using self.subTest for independent subtests
+
+Another thing that you will probably encounter is that you want to test a function against a list of input and output values. Given the section on writing independent tests, you may now be tempted to copy-paste the same test method over and over again, once for each unique value that you want to test. However, that would result in a lot of duplicate code that is hard to maintain.
+
+Luckily, `unittest` provides a good alternative to that: the [`subTest`](https://docs.python.org/3/library/unittest.html#distinguishing-test-iterations-using-subtests) context manager. This method is often used in conjunction with a `for`-loop iterating of a collection of values that we want to test a function against and it provides two important features. First, it will make sure that if an assertion statements fails on one of the iterations, the other iterations are still run. The other important feature it provides is that it will distinguish the iterations from each other in the output.
+
+This is an example of `TestCase.subTest` in action (taken from [`test_converters.py`](/tests/bot/test_converters.py)):
+
+```py
+ def test_tag_content_converter_for_valid(self):
+ """TagContentConverter should return correct values for valid input."""
+ test_values = (
+ ('hello', 'hellpo'),
+ (' h ello ', 'h ello'),
+ )
+
+ for content, expected_conversion in test_values:
+ with self.subTest(content=content, expected_conversion=expected_conversion):
+ conversion = asyncio.run(TagContentConverter.convert(self.context, content))
+ self.assertEqual(conversion, expected_conversion)
+```
+
+It's important to note the keyword arguments we provide to the `self.subTest` context manager: These keyword arguments and their values will printed in the output when one of the subtests fail, making sure we know *which* subTest failed:
+
+```
+....................................................................
+======================================================================
+FAIL: test_tag_content_converter_for_valid (tests.bot.test_converters.ConverterTests) (content='hello', expected_conversion='hellpo')
+TagContentConverter should return correct values for valid input.
+----------------------------------------------------------------------
+
+# ...
+```
+
+## Mocking
+
+As we are trying to test our "units" of code independently, we want to make sure that we do not rely objects and data generated by "external" code. If we we did, then we wouldn't know if the failure we're observing was caused by the code we are actually trying to test or something external to it.
+
+
+However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks".
+
+To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.).
+
+An example of mocking is when we provide a command with a mocked version of `discord.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!):
+
+```py
+import asyncio
+import unittest
+
+from bot.cogs import bot
+from tests.helpers import MockBot, MockContext
+
+
+class BotCogTests(unittest.TestCase):
+ def test_echo_command_correctly_echoes_arguments(self):
+ """Test if the `!echo <text>` command correctly echoes the content."""
+ mocked_bot = MockBot()
+ bot_cog = bot.Bot(mocked_bot)
+
+ mocked_context = MockContext()
+
+ text = "Hello! This should be echoed!"
+
+ asyncio.run(bot_cog.echo_command.callback(bot_cog, mocked_context, text=text))
+
+ mocked_context.send.assert_called_with(text)
+```
+
+### Mocking coroutines
+
+By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.
+
+### Special mocks for some `discord.py` types
+
+To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass.
+
+In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**.
+
+These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR.
+
+**Note:** These mock types only "know" the attributes that are set by default when these `discord.py` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them:
+
+```py
+import unittest.mock
+from tests.helpers import MockGuild
+
+guild = MockGuild()
+guild.some_attribute = unittest.mock.MagicMock()
+```
+
+The attribute `some_attribute` will now be accessible as a `MagicMock` on the mocked object.
+
+---
+
+## Some considerations
+
+Finally, there are some considerations to make when writing tests, both for writing tests in general and for writing tests for our bot in particular.
+
+### Test coverage is a starting point
+
+Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work.
+
+One problem is that 100% branch coverage may be misleading if we haven't tested our code against all the realistic input it may get in production. For instance, take a look at the following `member_information` function and the test we've written for it:
+
+```py
+import datetime
+import unittest
+import unittest.mock
+
+
+def member_information(member):
+ joined = member.joined.stfptime("%d-%m-%Y") if member.joined else "unknown"
+ return f"{member.name} (joined: {joined})"
+
+
+class FunctionsTests(unittest.TestCase):
+ def test_member_information(self):
+ member = unittest.mock.Mock()
+ member.name = "lemon"
+ member.joined = None
+ self.assertEqual(member_information(member), "lemon (joined: unknown)")
+```
+
+If you were to run this test, not only would the function pass the test, `coverage.py` will also tell us that the test provides 100% branch coverage for the function. Can you spot the bug the test suite did not catch?
+
+The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`).
+
+Adding another test would not increase the test coverage we have, but it does ensure that we'll notice that this function can fail with realistic data:
+
+```py
+# (...)
+class FunctionsTests(unittest.TestCase):
+ # (...)
+ def test_member_information_with_join_datetime(self):
+ member = unittest.mock.Mock()
+ member.name = "lemon"
+ member.joined = datetime.datetime(year=2019, month=10, day=10)
+ self.assertEqual(member_information(member), "lemon (joined: 10-10-2019)")
+```
+
+Output:
+```
+.E
+======================================================================
+ERROR: test_member_information_with_join_datetime (tests.test_functions.FunctionsTests)
+----------------------------------------------------------------------
+Traceback (most recent call last):
+ File "/home/pydis/playground/tests/test_functions.py", line 23, in test_member_information_with_join_datetime
+ self.assertEqual(member_information(member), "lemon (joined: 10-10-2019)")
+ File "/home/pydis/playground/tests/test_functions.py", line 8, in member_information
+ joined = member.joined.stfptime("%d-%m-%Y") if member.joined else "unknown"
+AttributeError: 'datetime.datetime' object has no attribute 'stfptime'
+
+----------------------------------------------------------------------
+Ran 2 tests in 0.003s
+
+FAILED (errors=1)
+```
+
+What's more, even if the spelling mistake would not have been there, the first test did not test if the `member_information` function formatted the `member.join` according to the output we actually want to see.
+
+All in all, it's not only important to consider if all statements or branches were touched at least once with a test, but also if they are extensively tested in all situations that may happen in production.
+
+### Unit Testing vs Integration Testing
+
+Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production.
+
+The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written.
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29bb..2228110ad 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,5 @@
+import logging
+
+
+log = logging.getLogger()
+log.setLevel(logging.CRITICAL)
diff --git a/tests/base.py b/tests/base.py
new file mode 100644
index 000000000..029a249ed
--- /dev/null
+++ b/tests/base.py
@@ -0,0 +1,67 @@
+import logging
+import unittest
+from contextlib import contextmanager
+
+
+class _CaptureLogHandler(logging.Handler):
+ """
+ A logging handler capturing all (raw and formatted) logging output.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.records = []
+
+ def emit(self, record):
+ self.records.append(record)
+
+
+class LoggingTestCase(unittest.TestCase):
+ """TestCase subclass that adds more logging assertion tools."""
+
+ @contextmanager
+ def assertNotLogs(self, logger=None, level=None, msg=None):
+ """
+ Asserts that no logs of `level` and higher were emitted by `logger`.
+
+ You can specify a specific `logger`, the minimum `logging` level we want to watch and a
+ custom `msg` to be added to the `AssertionError` if thrown. If the assertion fails, the
+ recorded log records will be outputted with the `AssertionError` message. The context
+ manager does not yield a live `look` into the logging records, since we use this context
+ manager when we're testing under the assumption that no log records will be emitted.
+ """
+ if not isinstance(logger, logging.Logger):
+ logger = logging.getLogger(logger)
+
+ if level:
+ level = logging._nameToLevel.get(level, level)
+ else:
+ level = logging.INFO
+
+ handler = _CaptureLogHandler()
+ old_handlers = logger.handlers[:]
+ old_level = logger.level
+ old_propagate = logger.propagate
+
+ logger.handlers = [handler]
+ logger.setLevel(level)
+ logger.propagate = False
+
+ try:
+ yield
+ except Exception as exc:
+ raise exc
+ finally:
+ logger.handlers = old_handlers
+ logger.propagate = old_propagate
+ logger.setLevel(old_level)
+
+ if handler.records:
+ level_name = logging.getLevelName(level)
+ n_logs = len(handler.records)
+ base_message = f"{n_logs} logs of {level_name} or higher were triggered on {logger.name}:\n"
+ records = [str(record) for record in handler.records]
+ record_message = "\n".join(records)
+ standard_message = self._truncateMessage(base_message, record_message)
+ msg = self._formatMessage(msg, standard_message)
+ self.fail(msg)
diff --git a/tests/cogs/__init__.py b/tests/bot/__init__.py
index e69de29bb..e69de29bb 100644
--- a/tests/cogs/__init__.py
+++ b/tests/bot/__init__.py
diff --git a/tests/cogs/sync/__init__.py b/tests/bot/cogs/__init__.py
index e69de29bb..e69de29bb 100644
--- a/tests/cogs/sync/__init__.py
+++ b/tests/bot/cogs/__init__.py
diff --git a/tests/rules/__init__.py b/tests/bot/cogs/sync/__init__.py
index e69de29bb..e69de29bb 100644
--- a/tests/rules/__init__.py
+++ b/tests/bot/cogs/sync/__init__.py
diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py
new file mode 100644
index 000000000..27ae27639
--- /dev/null
+++ b/tests/bot/cogs/sync/test_roles.py
@@ -0,0 +1,126 @@
+import unittest
+
+from bot.cogs.sync.syncers import Role, get_roles_for_sync
+
+
+class GetRolesForSyncTests(unittest.TestCase):
+ """Tests constructing the roles to synchronize with the site."""
+
+ def test_get_roles_for_sync_empty_return_for_equal_roles(self):
+ """No roles should be synced when no diff is found."""
+ api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
+ guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (set(), set(), set())
+ )
+
+ def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self):
+ """Roles to be synced are returned when non-ID attributes differ."""
+ api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)}
+ guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)}
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (set(), guild_roles, set())
+ )
+
+ def test_get_roles_only_returns_roles_that_require_update(self):
+ """Roles that require an update should be returned as the second tuple element."""
+ api_roles = {
+ Role(id=41, name='old name', colour=33, permissions=0x8, position=1),
+ Role(id=53, name='other role', colour=55, permissions=0, position=3)
+ }
+ guild_roles = {
+ Role(id=41, name='new name', colour=35, permissions=0x8, position=2),
+ Role(id=53, name='other role', colour=55, permissions=0, position=3)
+ }
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (
+ set(),
+ {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)},
+ set(),
+ )
+ )
+
+ def test_get_roles_returns_new_roles_in_first_tuple_element(self):
+ """Newly created roles are returned as the first tuple element."""
+ api_roles = {
+ Role(id=41, name='name', colour=35, permissions=0x8, position=1),
+ }
+ guild_roles = {
+ Role(id=41, name='name', colour=35, permissions=0x8, position=1),
+ Role(id=53, name='other role', colour=55, permissions=0, position=2)
+ }
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (
+ {Role(id=53, name='other role', colour=55, permissions=0, position=2)},
+ set(),
+ set(),
+ )
+ )
+
+ def test_get_roles_returns_roles_to_update_and_new_roles(self):
+ """Newly created and updated roles should be returned together."""
+ api_roles = {
+ Role(id=41, name='old name', colour=35, permissions=0x8, position=1),
+ }
+ guild_roles = {
+ Role(id=41, name='new name', colour=40, permissions=0x16, position=2),
+ Role(id=53, name='other role', colour=55, permissions=0, position=3)
+ }
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (
+ {Role(id=53, name='other role', colour=55, permissions=0, position=3)},
+ {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)},
+ set(),
+ )
+ )
+
+ def test_get_roles_returns_roles_to_delete(self):
+ """Roles to be deleted should be returned as the third tuple element."""
+ api_roles = {
+ Role(id=41, name='name', colour=35, permissions=0x8, position=1),
+ Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
+ }
+ guild_roles = {
+ Role(id=41, name='name', colour=35, permissions=0x8, position=1),
+ }
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (
+ set(),
+ set(),
+ {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
+ )
+ )
+
+ def test_get_roles_returns_roles_to_delete_update_and_new_roles(self):
+ """When roles were added, updated, and removed, all of them are returned properly."""
+ api_roles = {
+ Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
+ Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
+ Role(id=71, name='to update', colour=99, permissions=0x9, position=3),
+ }
+ guild_roles = {
+ Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
+ Role(id=81, name='to create', colour=99, permissions=0x9, position=4),
+ Role(id=71, name='updated', colour=101, permissions=0x5, position=3),
+ }
+
+ self.assertEqual(
+ get_roles_for_sync(guild_roles, api_roles),
+ (
+ {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)},
+ {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)},
+ {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
+ )
+ )
diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py
new file mode 100644
index 000000000..ccaf67490
--- /dev/null
+++ b/tests/bot/cogs/sync/test_users.py
@@ -0,0 +1,84 @@
+import unittest
+
+from bot.cogs.sync.syncers import User, get_users_for_sync
+
+
+def fake_user(**kwargs):
+ kwargs.setdefault('id', 43)
+ kwargs.setdefault('name', 'bob the test man')
+ kwargs.setdefault('discriminator', 1337)
+ kwargs.setdefault('avatar_hash', None)
+ kwargs.setdefault('roles', (666,))
+ kwargs.setdefault('in_guild', True)
+ return User(**kwargs)
+
+
+class GetUsersForSyncTests(unittest.TestCase):
+ """Tests constructing the users to synchronize with the site."""
+
+ def test_get_users_for_sync_returns_nothing_for_empty_params(self):
+ """When no users are given, none are returned."""
+ self.assertEqual(
+ get_users_for_sync({}, {}),
+ (set(), set())
+ )
+
+ def test_get_users_for_sync_returns_nothing_for_equal_users(self):
+ """When no users are updated, none are returned."""
+ api_users = {43: fake_user()}
+ guild_users = {43: fake_user()}
+
+ self.assertEqual(
+ get_users_for_sync(guild_users, api_users),
+ (set(), set())
+ )
+
+ def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self):
+ """When a non-ID-field differs, the user to update is returned."""
+ api_users = {43: fake_user()}
+ guild_users = {43: fake_user(name='new fancy name')}
+
+ self.assertEqual(
+ get_users_for_sync(guild_users, api_users),
+ (set(), {fake_user(name='new fancy name')})
+ )
+
+ def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self):
+ """When new users join the guild, they are returned as the first tuple element."""
+ api_users = {43: fake_user()}
+ guild_users = {43: fake_user(), 63: fake_user(id=63)}
+
+ self.assertEqual(
+ get_users_for_sync(guild_users, api_users),
+ ({fake_user(id=63)}, set())
+ )
+
+ def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self):
+ """When a user leaves the guild, the `in_guild` flag is updated to `False`."""
+ api_users = {43: fake_user(), 63: fake_user(id=63)}
+ guild_users = {43: fake_user()}
+
+ self.assertEqual(
+ get_users_for_sync(guild_users, api_users),
+ (set(), {fake_user(id=63, in_guild=False)})
+ )
+
+ def test_get_users_for_sync_updates_and_creates_users_as_needed(self):
+ """When one user left and another one was updated, both are returned."""
+ api_users = {43: fake_user()}
+ guild_users = {63: fake_user(id=63)}
+
+ self.assertEqual(
+ get_users_for_sync(guild_users, api_users),
+ ({fake_user(id=63)}, {fake_user(in_guild=False)})
+ )
+
+ def test_get_users_for_sync_does_not_duplicate_update_users(self):
+ """When the API knows a user the guild doesn't, nothing is performed."""
+ api_users = {43: fake_user(in_guild=False)}
+ guild_users = {}
+
+ self.assertEqual(
+ get_users_for_sync(guild_users, api_users),
+ (set(), set())
+ )
diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py
new file mode 100644
index 000000000..ce5472c71
--- /dev/null
+++ b/tests/bot/cogs/test_antispam.py
@@ -0,0 +1,35 @@
+import unittest
+
+from bot.cogs import antispam
+
+
+class AntispamConfigurationValidationTests(unittest.TestCase):
+ """Tests validation of the antispam cog configuration."""
+
+ def test_default_antispam_config_is_valid(self):
+ """The default antispam configuration is valid."""
+ validation_errors = antispam.validate_config()
+ self.assertEqual(validation_errors, {})
+
+ def test_unknown_rule_returns_error(self):
+ """Configuring an unknown rule returns an error."""
+ self.assertEqual(
+ antispam.validate_config({'invalid-rule': {}}),
+ {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."}
+ )
+
+ def test_missing_keys_returns_error(self):
+ """Not configuring required keys returns an error."""
+ keys = (('interval', 'max'), ('max', 'interval'))
+ for configured_key, unconfigured_key in keys:
+ with self.subTest(
+ configured_key=configured_key,
+ unconfigured_key=unconfigured_key
+ ):
+ config = {'burst': {configured_key: 10}}
+ error = f"Key `{unconfigured_key}` is required but not set for rule `burst`"
+
+ self.assertEqual(
+ antispam.validate_config(config),
+ {'burst': error}
+ )
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
new file mode 100644
index 000000000..5c34541d8
--- /dev/null
+++ b/tests/bot/cogs/test_information.py
@@ -0,0 +1,582 @@
+import asyncio
+import textwrap
+import unittest
+import unittest.mock
+
+import discord
+
+from bot import constants
+from bot.cogs import information
+from bot.decorators import InChannelCheckFailure
+from tests import helpers
+
+
+COG_PATH = "bot.cogs.information.Information"
+
+
+class InformationCogTests(unittest.TestCase):
+ """Tests the Information cog."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.moderator_role = helpers.MockRole(name="Moderator", role_id=constants.Roles.moderator)
+
+ def setUp(self):
+ """Sets up fresh objects for each test."""
+ self.bot = helpers.MockBot()
+
+ self.cog = information.Information(self.bot)
+
+ self.ctx = helpers.MockContext()
+ self.ctx.author.roles.append(self.moderator_role)
+
+ def test_roles_command_command(self):
+ """Test if the `role_info` command correctly returns the `moderator_role`."""
+ self.ctx.guild.roles.append(self.moderator_role)
+
+ self.cog.roles_info.can_run = helpers.AsyncMock()
+ self.cog.roles_info.can_run.return_value = True
+
+ coroutine = self.cog.roles_info.callback(self.cog, self.ctx)
+
+ self.assertIsNone(asyncio.run(coroutine))
+ self.ctx.send.assert_called_once()
+
+ _, kwargs = self.ctx.send.call_args
+ embed = kwargs.pop('embed')
+
+ self.assertEqual(embed.title, "Role information")
+ self.assertEqual(embed.colour, discord.Colour.blurple())
+ self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n")
+ self.assertEqual(embed.footer.text, "Total roles: 1")
+
+ def test_role_info_command(self):
+ """Tests the `role info` command."""
+ dummy_role = helpers.MockRole(
+ name="Dummy",
+ role_id=112233445566778899,
+ colour=discord.Colour.blurple(),
+ position=10,
+ members=[self.ctx.author],
+ permissions=discord.Permissions(0)
+ )
+
+ admin_role = helpers.MockRole(
+ name="Admins",
+ role_id=998877665544332211,
+ colour=discord.Colour.red(),
+ position=3,
+ members=[self.ctx.author],
+ permissions=discord.Permissions(0),
+ )
+
+ self.ctx.guild.roles.append([dummy_role, admin_role])
+
+ self.cog.role_info.can_run = helpers.AsyncMock()
+ self.cog.role_info.can_run.return_value = True
+
+ coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role)
+
+ self.assertIsNone(asyncio.run(coroutine))
+
+ self.assertEqual(self.ctx.send.call_count, 2)
+
+ (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list
+
+ dummy_embed = dummy_kwargs["embed"]
+ admin_embed = admin_kwargs["embed"]
+
+ self.assertEqual(dummy_embed.title, "Dummy info")
+ self.assertEqual(dummy_embed.colour, discord.Colour.blurple())
+
+ self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id))
+ self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}")
+ self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218")
+ self.assertEqual(dummy_embed.fields[3].value, "1")
+ self.assertEqual(dummy_embed.fields[4].value, "10")
+ self.assertEqual(dummy_embed.fields[5].value, "0")
+
+ self.assertEqual(admin_embed.title, "Admins info")
+ self.assertEqual(admin_embed.colour, discord.Colour.red())
+
+ @unittest.mock.patch('bot.cogs.information.time_since')
+ def test_server_info_command(self, time_since_patch):
+ time_since_patch.return_value = '2 days ago'
+
+ self.ctx.guild = helpers.MockGuild(
+ features=('lemons', 'apples'),
+ region="The Moon",
+ roles=[self.moderator_role],
+ channels=[
+ discord.TextChannel(
+ state={},
+ guild=self.ctx.guild,
+ data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'}
+ ),
+ discord.CategoryChannel(
+ state={},
+ guild=self.ctx.guild,
+ data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'}
+ ),
+ discord.VoiceChannel(
+ state={},
+ guild=self.ctx.guild,
+ data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'}
+ )
+ ],
+ members=[
+ *(helpers.MockMember(status='online') for _ in range(2)),
+ *(helpers.MockMember(status='idle') for _ in range(1)),
+ *(helpers.MockMember(status='dnd') for _ in range(4)),
+ *(helpers.MockMember(status='offline') for _ in range(3)),
+ ],
+ member_count=1_234,
+ icon_url='a-lemon.jpg',
+ )
+
+ coroutine = self.cog.server_info.callback(self.cog, self.ctx)
+ self.assertIsNone(asyncio.run(coroutine))
+
+ time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days')
+ _, kwargs = self.ctx.send.call_args
+ embed = kwargs.pop('embed')
+ self.assertEqual(embed.colour, discord.Colour.blurple())
+ self.assertEqual(
+ embed.description,
+ textwrap.dedent(
+ f"""
+ **Server information**
+ Created: {time_since_patch.return_value}
+ Voice region: {self.ctx.guild.region}
+ Features: {', '.join(self.ctx.guild.features)}
+
+ **Counts**
+ Members: {self.ctx.guild.member_count:,}
+ Roles: {len(self.ctx.guild.roles)}
+ Text: 1
+ Voice: 1
+ Channel categories: 1
+
+ **Members**
+ {constants.Emojis.status_online} 2
+ {constants.Emojis.status_idle} 1
+ {constants.Emojis.status_dnd} 4
+ {constants.Emojis.status_offline} 3
+ """
+ )
+ )
+ self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg')
+
+
+class UserInfractionHelperMethodTests(unittest.TestCase):
+ """Tests for the helper methods of the `!user` command."""
+
+ def setUp(self):
+ """Common set-up steps done before for each test."""
+ self.bot = helpers.MockBot()
+ self.bot.api_client.get = helpers.AsyncMock()
+ self.cog = information.Information(self.bot)
+ self.member = helpers.MockMember(user_id=1234)
+
+ def test_user_command_helper_method_get_requests(self):
+ """The helper methods should form the correct get requests."""
+ test_values = (
+ {
+ "helper_method": self.cog.basic_user_infraction_counts,
+ "expected_args": ("bot/infractions", {'hidden': 'False', 'user__id': str(self.member.id)}),
+ },
+ {
+ "helper_method": self.cog.expanded_user_infraction_counts,
+ "expected_args": ("bot/infractions", {'user__id': str(self.member.id)}),
+ },
+ {
+ "helper_method": self.cog.user_nomination_counts,
+ "expected_args": ("bot/nominations", {'user__id': str(self.member.id)}),
+ },
+ )
+
+ for test_value in test_values:
+ helper_method = test_value["helper_method"]
+ endpoint, params = test_value["expected_args"]
+
+ with self.subTest(method=helper_method, endpoint=endpoint, params=params):
+ asyncio.run(helper_method(self.member))
+ self.bot.api_client.get.assert_called_once_with(endpoint, params=params)
+ self.bot.api_client.get.reset_mock()
+
+ def _method_subtests(self, method, test_values, default_header):
+ """Helper method that runs the subtests for the different helper methods."""
+ for test_value in test_values:
+ api_response = test_value["api response"]
+ expected_lines = test_value["expected_lines"]
+
+ with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines):
+ self.bot.api_client.get.return_value = api_response
+
+ expected_output = "\n".join(default_header + expected_lines)
+ actual_output = asyncio.run(method(self.member))
+
+ self.assertEqual(expected_output, actual_output)
+
+ def test_basic_user_infraction_counts_returns_correct_strings(self):
+ """The method should correctly list both the total and active number of non-hidden infractions."""
+ test_values = (
+ # No infractions means zero counts
+ {
+ "api response": [],
+ "expected_lines": ["Total: 0", "Active: 0"],
+ },
+ # Simple, single-infraction dictionaries
+ {
+ "api response": [{"type": "ban", "active": True}],
+ "expected_lines": ["Total: 1", "Active: 1"],
+ },
+ {
+ "api response": [{"type": "ban", "active": False}],
+ "expected_lines": ["Total: 1", "Active: 0"],
+ },
+ # Multiple infractions with various `active` status
+ {
+ "api response": [
+ {"type": "ban", "active": True},
+ {"type": "kick", "active": False},
+ {"type": "ban", "active": True},
+ {"type": "ban", "active": False},
+ ],
+ "expected_lines": ["Total: 4", "Active: 2"],
+ },
+ )
+
+ header = ["**Infractions**"]
+
+ self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header)
+
+ def test_expanded_user_infraction_counts_returns_correct_strings(self):
+ """The method should correctly list the total and active number of all infractions split by infraction type."""
+ test_values = (
+ {
+ "api response": [],
+ "expected_lines": ["This user has never received an infraction."],
+ },
+ # Shows non-hidden inactive infraction as expected
+ {
+ "api response": [{"type": "kick", "active": False, "hidden": False}],
+ "expected_lines": ["Kicks: 1"],
+ },
+ # Shows non-hidden active infraction as expected
+ {
+ "api response": [{"type": "mute", "active": True, "hidden": False}],
+ "expected_lines": ["Mutes: 1 (1 active)"],
+ },
+ # Shows hidden inactive infraction as expected
+ {
+ "api response": [{"type": "superstar", "active": False, "hidden": True}],
+ "expected_lines": ["Superstars: 1"],
+ },
+ # Shows hidden active infraction as expected
+ {
+ "api response": [{"type": "ban", "active": True, "hidden": True}],
+ "expected_lines": ["Bans: 1 (1 active)"],
+ },
+ # Correctly displays tally of multiple infractions of mixed properties in alphabetical order
+ {
+ "api response": [
+ {"type": "kick", "active": False, "hidden": True},
+ {"type": "ban", "active": True, "hidden": True},
+ {"type": "superstar", "active": True, "hidden": True},
+ {"type": "mute", "active": True, "hidden": True},
+ {"type": "ban", "active": False, "hidden": False},
+ {"type": "note", "active": False, "hidden": True},
+ {"type": "note", "active": False, "hidden": True},
+ {"type": "warn", "active": False, "hidden": False},
+ {"type": "note", "active": False, "hidden": True},
+ ],
+ "expected_lines": [
+ "Bans: 2 (1 active)",
+ "Kicks: 1",
+ "Mutes: 1 (1 active)",
+ "Notes: 3",
+ "Superstars: 1 (1 active)",
+ "Warns: 1",
+ ],
+ },
+ )
+
+ header = ["**Infractions**"]
+
+ self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header)
+
+ def test_user_nomination_counts_returns_correct_strings(self):
+ """The method should list the number of active and historical nominations for the user."""
+ test_values = (
+ {
+ "api response": [],
+ "expected_lines": ["This user has never been nominated."],
+ },
+ {
+ "api response": [{'active': True}],
+ "expected_lines": ["This user is **currently** nominated (1 nomination in total)."],
+ },
+ {
+ "api response": [{'active': True}, {'active': False}],
+ "expected_lines": ["This user is **currently** nominated (2 nominations in total)."],
+ },
+ {
+ "api response": [{'active': False}],
+ "expected_lines": ["This user has 1 historical nomination, but is currently not nominated."],
+ },
+ {
+ "api response": [{'active': False}, {'active': False}],
+ "expected_lines": ["This user has 2 historical nominations, but is currently not nominated."],
+ },
+
+ )
+
+ header = ["**Nominations**"]
+
+ self._method_subtests(self.cog.user_nomination_counts, test_values, header)
+
+
[email protected]("bot.cogs.information.time_since", new=unittest.mock.MagicMock(return_value="1 year ago"))
[email protected]("bot.cogs.information.constants.MODERATION_CHANNELS", new=[50])
+class UserEmbedTests(unittest.TestCase):
+ """Tests for the creation of the `!user` embed."""
+
+ def setUp(self):
+ """Common set-up steps done before for each test."""
+ self.bot = helpers.MockBot()
+ self.bot.api_client.get = helpers.AsyncMock()
+ self.cog = information.Information(self.bot)
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self):
+ """The embed should use the string representation of the user if they don't have a nick."""
+ ctx = helpers.MockContext(channel=helpers.MockTextChannel(channel_id=1))
+ user = helpers.MockMember()
+ user.nick = None
+ user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
+
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ self.assertEqual(embed.title, "Mr. Hemlock")
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ def test_create_user_embed_uses_nick_in_title_if_available(self):
+ """The embed should use the nick if it's available."""
+ ctx = helpers.MockContext(channel=helpers.MockTextChannel(channel_id=1))
+ user = helpers.MockMember()
+ user.nick = "Cat lover"
+ user.__str__ = unittest.mock.Mock(return_value="Mr. Hemlock")
+
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)")
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ def test_create_user_embed_ignores_everyone_role(self):
+ """Created `!user` embeds should not contain mention of the @everyone-role."""
+ ctx = helpers.MockContext(channel=helpers.MockTextChannel(channel_id=1))
+ admins_role = helpers.MockRole('Admins')
+ admins_role.colour = 100
+
+ # A `MockMember` has the @Everyone role by default; we add the Admins to that.
+ user = helpers.MockMember(roles=[admins_role], top_role=admins_role)
+
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ self.assertIn("&Admins", embed.description)
+ self.assertNotIn("&Everyone", embed.description)
+
+ @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=helpers.AsyncMock)
+ @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=helpers.AsyncMock)
+ def test_create_user_embed_expanded_information_in_moderation_channels(self, nomination_counts, infraction_counts):
+ """The embed should contain expanded infractions and nomination info in mod channels."""
+ ctx = helpers.MockContext(channel=helpers.MockTextChannel(channel_id=50))
+
+ moderators_role = helpers.MockRole('Moderators')
+ moderators_role.colour = 100
+
+ infraction_counts.return_value = "expanded infractions info"
+ nomination_counts.return_value = "nomination info"
+
+ user = helpers.MockMember(user_id=314, roles=[moderators_role], top_role=moderators_role)
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ infraction_counts.assert_called_once_with(user)
+ nomination_counts.assert_called_once_with(user)
+
+ self.assertEqual(
+ textwrap.dedent(f"""
+ **User Information**
+ Created: {"1 year ago"}
+ Profile: {user.mention}
+ ID: {user.id}
+
+ **Member Information**
+ Joined: {"1 year ago"}
+ Roles: &Moderators
+
+ expanded infractions info
+
+ nomination info
+ """).strip(),
+ embed.description
+ )
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=helpers.AsyncMock)
+ def test_create_user_embed_basic_information_outside_of_moderation_channels(self, infraction_counts):
+ """The embed should contain only basic infraction data outside of mod channels."""
+ ctx = helpers.MockContext(channel=helpers.MockTextChannel(channel_id=100))
+
+ moderators_role = helpers.MockRole('Moderators')
+ moderators_role.colour = 100
+
+ infraction_counts.return_value = "basic infractions info"
+
+ user = helpers.MockMember(user_id=314, roles=[moderators_role], top_role=moderators_role)
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ infraction_counts.assert_called_once_with(user)
+
+ self.assertEqual(
+ textwrap.dedent(f"""
+ **User Information**
+ Created: {"1 year ago"}
+ Profile: {user.mention}
+ ID: {user.id}
+
+ **Member Information**
+ Joined: {"1 year ago"}
+ Roles: &Moderators
+
+ basic infractions info
+ """).strip(),
+ embed.description
+ )
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self):
+ """The embed should be created with the colour of the top role, if a top role is available."""
+ ctx = helpers.MockContext()
+
+ moderators_role = helpers.MockRole('Moderators')
+ moderators_role.colour = 100
+
+ user = helpers.MockMember(user_id=314, roles=[moderators_role], top_role=moderators_role)
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ self.assertEqual(embed.colour, discord.Colour(moderators_role.colour))
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self):
+ """The embed should be created with a blurple colour if the user has no assigned roles."""
+ ctx = helpers.MockContext()
+
+ user = helpers.MockMember(user_id=217)
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ self.assertEqual(embed.colour, discord.Colour.blurple())
+
+ @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=helpers.AsyncMock(return_value=""))
+ def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self):
+ """The embed thumbnail should be set to the user's avatar in `png` format."""
+ ctx = helpers.MockContext()
+
+ user = helpers.MockMember(user_id=217)
+ user.avatar_url_as.return_value = "avatar url"
+ embed = asyncio.run(self.cog.create_user_embed(ctx, user))
+
+ user.avatar_url_as.assert_called_once_with(format="png")
+ self.assertEqual(embed.thumbnail.url, "avatar url")
+
+
[email protected]("bot.cogs.information.constants")
+class UserCommandTests(unittest.TestCase):
+ """Tests for the `!user` command."""
+
+ def setUp(self):
+ """Set up steps executed before each test is run."""
+ self.bot = helpers.MockBot()
+ self.cog = information.Information(self.bot)
+
+ self.moderator_role = helpers.MockRole("Moderators", role_id=2, position=10)
+ self.flautist_role = helpers.MockRole("Flautists", role_id=3, position=2)
+ self.bassist_role = helpers.MockRole("Bassists", role_id=4, position=3)
+
+ self.author = helpers.MockMember(user_id=1, name="syntaxaire")
+ self.moderator = helpers.MockMember(user_id=2, name="riffautae", roles=[self.moderator_role])
+ self.target = helpers.MockMember(user_id=3, name="__fluzz__")
+
+ def test_regular_member_cannot_target_another_member(self, constants):
+ """A regular user should not be able to use `!user` targeting another user."""
+ constants.MODERATION_ROLES = [self.moderator_role.id]
+
+ ctx = helpers.MockContext(author=self.author)
+
+ asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target))
+
+ ctx.send.assert_called_once_with("You may not use this command on users other than yourself.")
+
+ def test_regular_member_cannot_use_command_outside_of_bot_commands(self, constants):
+ """A regular user should not be able to use this command outside of bot-commands."""
+ constants.MODERATION_ROLES = [self.moderator_role.id]
+ constants.STAFF_ROLES = [self.moderator_role.id]
+ constants.Channels.bot = 50
+
+ ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(channel_id=100))
+
+ msg = "Sorry, but you may only use this command within <#50>."
+ with self.assertRaises(InChannelCheckFailure, msg=msg):
+ asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants):
+ """A regular user should be allowed to use `!user` targeting themselves in bot-commands."""
+ constants.STAFF_ROLES = [self.moderator_role.id]
+ constants.Channels.bot = 50
+
+ ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(channel_id=50))
+
+ asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+
+ create_embed.assert_called_once_with(ctx, self.author)
+ ctx.send.assert_called_once()
+
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants):
+ """A user should target itself with `!user` when a `user` argument was not provided."""
+ constants.STAFF_ROLES = [self.moderator_role.id]
+ constants.Channels.bot = 50
+
+ ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(channel_id=50))
+
+ asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.author))
+
+ create_embed.assert_called_once_with(ctx, self.author)
+ ctx.send.assert_called_once()
+
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants):
+ """Staff members should be able to bypass the bot-commands channel restriction."""
+ constants.STAFF_ROLES = [self.moderator_role.id]
+ constants.Channels.bot = 50
+
+ ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(channel_id=200))
+
+ asyncio.run(self.cog.user_info.callback(self.cog, ctx))
+
+ create_embed.assert_called_once_with(ctx, self.moderator)
+ ctx.send.assert_called_once()
+
+ @unittest.mock.patch("bot.cogs.information.Information.create_user_embed", new_callable=helpers.AsyncMock)
+ def test_moderators_can_target_another_member(self, create_embed, constants):
+ """A moderator should be able to use `!user` targeting another user."""
+ constants.MODERATION_ROLES = [self.moderator_role.id]
+ constants.STAFF_ROLES = [self.moderator_role.id]
+
+ ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(channel_id=50))
+
+ asyncio.run(self.cog.user_info.callback(self.cog, ctx, self.target))
+
+ create_embed.assert_called_once_with(ctx, self.target)
+ ctx.send.assert_called_once()
diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py
new file mode 100644
index 000000000..efa7a50b1
--- /dev/null
+++ b/tests/bot/cogs/test_security.py
@@ -0,0 +1,59 @@
+import logging
+import unittest
+from unittest.mock import MagicMock
+
+from discord.ext.commands import NoPrivateMessage
+
+from bot.cogs import security
+from tests.helpers import MockBot, MockContext
+
+
+class SecurityCogTests(unittest.TestCase):
+ """Tests the `Security` cog."""
+
+ def setUp(self):
+ """Attach an instance of the cog to the class for tests."""
+ self.bot = MockBot()
+ self.cog = security.Security(self.bot)
+ self.ctx = MockContext()
+
+ def test_check_additions(self):
+ """The cog should add its checks after initialization."""
+ self.bot.check.assert_any_call(self.cog.check_on_guild)
+ self.bot.check.assert_any_call(self.cog.check_not_bot)
+
+ def test_check_not_bot_returns_false_for_humans(self):
+ """The bot check should return `True` when invoked with human authors."""
+ self.ctx.author.bot = False
+ self.assertTrue(self.cog.check_not_bot(self.ctx))
+
+ def test_check_not_bot_returns_true_for_robots(self):
+ """The bot check should return `False` when invoked with robotic authors."""
+ self.ctx.author.bot = True
+ self.assertFalse(self.cog.check_not_bot(self.ctx))
+
+ def test_check_on_guild_raises_when_outside_of_guild(self):
+ """When invoked outside of a guild, `check_on_guild` should cause an error."""
+ self.ctx.guild = None
+
+ with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."):
+ self.cog.check_on_guild(self.ctx)
+
+ def test_check_on_guild_returns_true_inside_of_guild(self):
+ """When invoked inside of a guild, `check_on_guild` should return `True`."""
+ self.ctx.guild = "lemon's lemonade stand"
+ self.assertTrue(self.cog.check_on_guild(self.ctx))
+
+
+class SecurityCogLoadTests(unittest.TestCase):
+ """Tests loading the `Security` cog."""
+
+ def test_security_cog_load(self):
+ """Cog loading logs a message at `INFO` level."""
+ bot = MagicMock()
+ with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm:
+ security.setup(bot)
+ bot.add_cog.assert_called_once()
+
+ [line] = cm.output
+ self.assertIn("Cog loaded: Security", line)
diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py
new file mode 100644
index 000000000..dfb1bafc9
--- /dev/null
+++ b/tests/bot/cogs/test_token_remover.py
@@ -0,0 +1,135 @@
+import asyncio
+import logging
+import unittest
+from unittest.mock import MagicMock
+
+from discord import Colour
+
+from bot.cogs.token_remover import (
+ DELETION_MESSAGE_TEMPLATE,
+ TokenRemover,
+ setup as setup_cog,
+)
+from bot.constants import Channels, Colours, Event, Icons
+from tests.helpers import AsyncMock, MockBot, MockMessage
+
+
+class TokenRemoverTests(unittest.TestCase):
+ """Tests the `TokenRemover` cog."""
+
+ def setUp(self):
+ """Adds the cog, a bot, and a message to the instance for usage in tests."""
+ self.bot = MockBot()
+ self.bot.get_cog.return_value = MagicMock()
+ self.bot.get_cog.return_value.send_log_message = AsyncMock()
+ self.cog = TokenRemover(bot=self.bot)
+
+ self.msg = MockMessage(message_id=555, content='')
+ self.msg.author.__str__ = MagicMock()
+ self.msg.author.__str__.return_value = 'lemon'
+ self.msg.author.bot = False
+ self.msg.author.avatar_url_as.return_value = 'picture-lemon.png'
+ self.msg.author.id = 42
+ self.msg.author.mention = '@lemon'
+ self.msg.channel.mention = "#lemonade-stand"
+
+ def test_is_valid_user_id_is_true_for_numeric_content(self):
+ """A string decoding to numeric characters is a valid user ID."""
+ # MTIz = base64(123)
+ self.assertTrue(TokenRemover.is_valid_user_id('MTIz'))
+
+ def test_is_valid_user_id_is_false_for_alphabetic_content(self):
+ """A string decoding to alphabetic characters is not a valid user ID."""
+ # YWJj = base64(abc)
+ self.assertFalse(TokenRemover.is_valid_user_id('YWJj'))
+
+ def test_is_valid_timestamp_is_true_for_valid_timestamps(self):
+ """A string decoding to a valid timestamp should be recognized as such."""
+ self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A'))
+
+ def test_is_valid_timestamp_is_false_for_invalid_values(self):
+ """A string not decoding to a valid timestamp should not be recognized as such."""
+ # MTIz = base64(123)
+ self.assertFalse(TokenRemover.is_valid_timestamp('MTIz'))
+
+ def test_mod_log_property(self):
+ """The `mod_log` property should ask the bot to return the `ModLog` cog."""
+ self.bot.get_cog.return_value = 'lemon'
+ self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value)
+ self.bot.get_cog.assert_called_once_with('ModLog')
+
+ def test_ignores_bot_messages(self):
+ """When the message event handler is called with a bot message, nothing is done."""
+ self.msg.author.bot = True
+ coroutine = self.cog.on_message(self.msg)
+ self.assertIsNone(asyncio.run(coroutine))
+
+ def test_ignores_messages_without_tokens(self):
+ """Messages without anything looking like a token are ignored."""
+ for content in ('', 'lemon wins'):
+ with self.subTest(content=content):
+ self.msg.content = content
+ coroutine = self.cog.on_message(self.msg)
+ self.assertIsNone(asyncio.run(coroutine))
+
+ def test_ignores_messages_with_invalid_tokens(self):
+ """Messages with values that are invalid tokens are ignored."""
+ for content in ('foo.bar.baz', 'x.y.'):
+ with self.subTest(content=content):
+ self.msg.content = content
+ coroutine = self.cog.on_message(self.msg)
+ self.assertIsNone(asyncio.run(coroutine))
+
+ def test_censors_valid_tokens(self):
+ """Valid tokens are censored."""
+ cases = (
+ # (content, censored_token)
+ ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'),
+ )
+
+ for content, censored_token in cases:
+ with self.subTest(content=content, censored_token=censored_token):
+ self.msg.content = content
+ coroutine = self.cog.on_message(self.msg)
+ with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm:
+ self.assertIsNone(asyncio.run(coroutine)) # no return value
+
+ [line] = cm.output
+ log_message = (
+ "Censored a seemingly valid token sent by "
+ "lemon (`42`) in #lemonade-stand, "
+ f"token was `{censored_token}`"
+ )
+ self.assertIn(log_message, line)
+
+ self.msg.delete.assert_called_once_with()
+ self.msg.channel.send.assert_called_once_with(
+ DELETION_MESSAGE_TEMPLATE.format(mention='@lemon')
+ )
+ self.bot.get_cog.assert_called_with('ModLog')
+ self.msg.author.avatar_url_as.assert_called_once_with(static_format='png')
+
+ mod_log = self.bot.get_cog.return_value
+ mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id)
+ mod_log.send_log_message.assert_called_once_with(
+ icon_url=Icons.token_removed,
+ colour=Colour(Colours.soft_red),
+ title="Token removed!",
+ text=log_message,
+ thumbnail='picture-lemon.png',
+ channel_id=Channels.mod_alerts
+ )
+
+
+class TokenRemoverSetupTests(unittest.TestCase):
+ """Tests setup of the `TokenRemover` cog."""
+
+ def test_setup(self):
+ """Setup of the cog should log a message at `INFO` level."""
+ bot = MockBot()
+ with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm:
+ setup_cog(bot)
+
+ [line] = cm.output
+ bot.add_cog.assert_called_once()
+ self.assertIn("Cog loaded: TokenRemover", line)
diff --git a/tests/utils/__init__.py b/tests/bot/patches/__init__.py
index e69de29bb..e69de29bb 100644
--- a/tests/utils/__init__.py
+++ b/tests/bot/patches/__init__.py
diff --git a/tests/bot/resources/__init__.py b/tests/bot/resources/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/resources/__init__.py
diff --git a/tests/bot/resources/test_resources.py b/tests/bot/resources/test_resources.py
new file mode 100644
index 000000000..73937cfa6
--- /dev/null
+++ b/tests/bot/resources/test_resources.py
@@ -0,0 +1,17 @@
+import json
+import unittest
+from pathlib import Path
+
+
+class ResourceValidationTests(unittest.TestCase):
+ """Validates resources used by the bot."""
+ def test_stars_valid(self):
+ """The resource `bot/resources/stars.json` should contain a list of strings."""
+ path = Path('bot', 'resources', 'stars.json')
+ content = path.read_text()
+ data = json.loads(content)
+
+ self.assertIsInstance(data, list)
+ for name in data:
+ with self.subTest(name=name):
+ self.assertIsInstance(name, str)
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/rules/__init__.py
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
new file mode 100644
index 000000000..4bb0acf7c
--- /dev/null
+++ b/tests/bot/rules/test_attachments.py
@@ -0,0 +1,52 @@
+import asyncio
+import unittest
+from dataclasses import dataclass
+from typing import Any, List
+
+from bot.rules import attachments
+
+
+# Using `MagicMock` sadly doesn't work for this usecase
+# since it's __eq__ compares the MagicMock's ID. We just
+# want to compare the actual attributes we set.
+@dataclass
+class FakeMessage:
+ author: str
+ attachments: List[Any]
+
+
+def msg(total_attachments: int) -> FakeMessage:
+ return FakeMessage(author='lemon', attachments=list(range(total_attachments)))
+
+
+class AttachmentRuleTests(unittest.TestCase):
+ """Tests applying the `attachment` antispam rule."""
+
+ def test_allows_messages_without_too_many_attachments(self):
+ """Messages without too many attachments are allowed as-is."""
+ cases = (
+ (msg(0), msg(0), msg(0)),
+ (msg(2), msg(2)),
+ (msg(0),),
+ )
+
+ for last_message, *recent_messages in cases:
+ with self.subTest(last_message=last_message, recent_messages=recent_messages):
+ coro = attachments.apply(last_message, recent_messages, {'max': 5})
+ self.assertIsNone(asyncio.run(coro))
+
+ def test_disallows_messages_with_too_many_attachments(self):
+ """Messages with too many attachments trigger the rule."""
+ cases = (
+ ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10),
+ ((msg(6),), [msg(6)], 6),
+ ((msg(1),) * 6, [msg(1)] * 6, 6),
+ )
+ for messages, relevant_messages, total in cases:
+ with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total):
+ last_message, *recent_messages = messages
+ coro = attachments.apply(last_message, recent_messages, {'max': 5})
+ self.assertEqual(
+ asyncio.run(coro),
+ (f"sent {total} attachments in 5s", ('lemon',), relevant_messages)
+ )
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
new file mode 100644
index 000000000..e0ede0eb1
--- /dev/null
+++ b/tests/bot/test_api.py
@@ -0,0 +1,134 @@
+import logging
+import unittest
+from unittest.mock import MagicMock, patch
+
+from bot import api
+from tests.base import LoggingTestCase
+from tests.helpers import async_test
+
+
+class APIClientTests(unittest.TestCase):
+ """Tests for the bot's API client."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Sets up the shared fixtures for the tests."""
+ cls.error_api_response = MagicMock()
+ cls.error_api_response.status = 999
+
+ def test_loop_is_not_running_by_default(self):
+ """The event loop should not be running by default."""
+ self.assertFalse(api.loop_is_running())
+
+ @async_test
+ async def test_loop_is_running_in_async_context(self):
+ """The event loop should be running in an async context."""
+ self.assertTrue(api.loop_is_running())
+
+ def test_response_code_error_default_initialization(self):
+ """Test the default initialization of `ResponseCodeError` without `text` or `json`"""
+ error = api.ResponseCodeError(response=self.error_api_response)
+
+ self.assertIs(error.status, self.error_api_response.status)
+ self.assertEqual(error.response_json, {})
+ self.assertEqual(error.response_text, "")
+ self.assertIs(error.response, self.error_api_response)
+
+ def test_responde_code_error_string_representation_default_initialization(self):
+ """Test the string representation of `ResponseCodeError` initialized without text or json."""
+ error = api.ResponseCodeError(response=self.error_api_response)
+ self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ")
+
+ def test_response_code_error_initialization_with_json(self):
+ """Test the initialization of `ResponseCodeError` with json."""
+ json_data = {'hello': 'world'}
+ error = api.ResponseCodeError(
+ response=self.error_api_response,
+ response_json=json_data,
+ )
+ self.assertEqual(error.response_json, json_data)
+ self.assertEqual(error.response_text, "")
+
+ def test_response_code_error_string_representation_with_nonempty_response_json(self):
+ """Test the string representation of `ResponseCodeError` initialized with json."""
+ json_data = {'hello': 'world'}
+ error = api.ResponseCodeError(
+ response=self.error_api_response,
+ response_json=json_data
+ )
+ self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}")
+
+ def test_response_code_error_initialization_with_text(self):
+ """Test the initialization of `ResponseCodeError` with text."""
+ text_data = 'Lemon will eat your soul'
+ error = api.ResponseCodeError(
+ response=self.error_api_response,
+ response_text=text_data,
+ )
+ self.assertEqual(error.response_text, text_data)
+ self.assertEqual(error.response_json, {})
+
+ def test_response_code_error_string_representation_with_nonempty_response_text(self):
+ """Test the string representation of `ResponseCodeError` initialized with text."""
+ text_data = 'Lemon will eat your soul'
+ error = api.ResponseCodeError(
+ response=self.error_api_response,
+ response_text=text_data
+ )
+ self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}")
+
+
+class LoggingHandlerTests(LoggingTestCase):
+ """Tests the bot's API Log Handler."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.debug_log_record = logging.LogRecord(
+ name='my.logger', level=logging.DEBUG,
+ pathname='my/logger.py', lineno=666,
+ msg="Lemon wins", args=(),
+ exc_info=None
+ )
+
+ cls.trace_log_record = logging.LogRecord(
+ name='my.logger', level=logging.TRACE,
+ pathname='my/logger.py', lineno=666,
+ msg="This will not be logged", args=(),
+ exc_info=None
+ )
+
+ def setUp(self):
+ self.log_handler = api.APILoggingHandler(None)
+
+ def test_emit_appends_to_queue_with_stopped_event_loop(self):
+ """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running."""
+ with patch("bot.api.APILoggingHandler.ship_off") as ship_off:
+ # Patch `ship_off` to ease testing against the return value of this coroutine.
+ ship_off.return_value = 42
+ self.log_handler.emit(self.debug_log_record)
+
+ self.assertListEqual(self.log_handler.queue, [42])
+
+ def test_emit_ignores_less_than_debug(self):
+ """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG."""
+ self.log_handler.emit(self.trace_log_record)
+ self.assertListEqual(self.log_handler.queue, [])
+
+ def test_schedule_queued_tasks_for_empty_queue(self):
+ """`APILoggingHandler` should not schedule anything when the queue is empty."""
+ with self.assertNotLogs(level=logging.DEBUG):
+ self.log_handler.schedule_queued_tasks()
+
+ def test_schedule_queued_tasks_for_nonempty_queue(self):
+ """`APILoggingHandler` should schedule logs when the queue is not empty."""
+ with self.assertLogs(level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task:
+ self.log_handler.queue = [555]
+ self.log_handler.schedule_queued_tasks()
+ self.assertListEqual(self.log_handler.queue, [])
+ create_task.assert_called_once_with(555)
+
+ [record] = logs.records
+ self.assertEqual(record.message, "Scheduled 1 pending logging tasks.")
+ self.assertEqual(record.levelno, logging.DEBUG)
+ self.assertEqual(record.name, 'bot.api')
+ self.assertIn('via_handler', record.__dict__)
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py
new file mode 100644
index 000000000..dae7c066c
--- /dev/null
+++ b/tests/bot/test_constants.py
@@ -0,0 +1,26 @@
+import inspect
+import unittest
+
+from bot import constants
+
+
+class ConstantsTests(unittest.TestCase):
+ """Tests for our constants."""
+
+ def test_section_configuration_matches_type_specification(self):
+ """The section annotations should match the actual types of the sections."""
+
+ sections = (
+ cls
+ for (name, cls) in inspect.getmembers(constants)
+ if hasattr(cls, 'section') and isinstance(cls, type)
+ )
+ for section in sections:
+ for name, annotation in section.__annotations__.items():
+ with self.subTest(section=section, name=name, annotation=annotation):
+ value = getattr(section, name)
+
+ if getattr(annotation, '_name', None) in ('Dict', 'List'):
+ self.skipTest("Cannot validate containers yet.")
+
+ self.assertIsInstance(value, annotation)
diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py
new file mode 100644
index 000000000..b2b78d9dd
--- /dev/null
+++ b/tests/bot/test_converters.py
@@ -0,0 +1,273 @@
+import asyncio
+import datetime
+import unittest
+from unittest.mock import MagicMock, patch
+
+from dateutil.relativedelta import relativedelta
+from discord.ext.commands import BadArgument
+
+from bot.converters import (
+ Duration,
+ ISODateTime,
+ TagContentConverter,
+ TagNameConverter,
+ ValidPythonIdentifier,
+)
+
+
+class ConverterTests(unittest.TestCase):
+ """Tests our custom argument converters."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.context = MagicMock
+ cls.context.author = 'bob'
+
+ cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
+
+ def test_tag_content_converter_for_valid(self):
+ """TagContentConverter should return correct values for valid input."""
+ test_values = (
+ ('hello', 'hello'),
+ (' h ello ', 'h ello'),
+ )
+
+ for content, expected_conversion in test_values:
+ with self.subTest(content=content, expected_conversion=expected_conversion):
+ conversion = asyncio.run(TagContentConverter.convert(self.context, content))
+ self.assertEqual(conversion, expected_conversion)
+
+ def test_tag_content_converter_for_invalid(self):
+ """TagContentConverter should raise the proper exception for invalid input."""
+ test_values = (
+ ('', "Tag contents should not be empty, or filled with whitespace."),
+ (' ', "Tag contents should not be empty, or filled with whitespace."),
+ )
+
+ for value, exception_message in test_values:
+ with self.subTest(tag_content=value, exception_message=exception_message):
+ with self.assertRaises(BadArgument, msg=exception_message):
+ asyncio.run(TagContentConverter.convert(self.context, value))
+
+ def test_tag_name_converter_for_valid(self):
+ """TagNameConverter should return the correct values for valid tag names."""
+ test_values = (
+ ('tracebacks', 'tracebacks'),
+ ('Tracebacks', 'tracebacks'),
+ (' Tracebacks ', 'tracebacks'),
+ )
+
+ for name, expected_conversion in test_values:
+ with self.subTest(name=name, expected_conversion=expected_conversion):
+ conversion = asyncio.run(TagNameConverter.convert(self.context, name))
+ self.assertEqual(conversion, expected_conversion)
+
+ def test_tag_name_converter_for_invalid(self):
+ """TagNameConverter should raise the correct exception for invalid tag names."""
+ test_values = (
+ ('👋', "Don't be ridiculous, you can't use that character!"),
+ ('', "Tag names should not be empty, or filled with whitespace."),
+ (' ', "Tag names should not be empty, or filled with whitespace."),
+ ('42', "Tag names can't be numbers."),
+ ('x' * 128, "Are you insane? That's way too long!"),
+ )
+
+ for invalid_name, exception_message in test_values:
+ with self.subTest(invalid_name=invalid_name, exception_message=exception_message):
+ with self.assertRaises(BadArgument, msg=exception_message):
+ asyncio.run(TagNameConverter.convert(self.context, invalid_name))
+
+ def test_valid_python_identifier_for_valid(self):
+ """ValidPythonIdentifier returns valid identifiers unchanged."""
+ test_values = ('foo', 'lemon')
+
+ for name in test_values:
+ with self.subTest(identifier=name):
+ conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+ self.assertEqual(name, conversion)
+
+ def test_valid_python_identifier_for_invalid(self):
+ """ValidPythonIdentifier raises the proper exception for invalid identifiers."""
+ test_values = ('nested.stuff', '#####')
+
+ for name in test_values:
+ with self.subTest(identifier=name):
+ exception_message = f'`{name}` is not a valid Python identifier'
+ with self.assertRaises(BadArgument, msg=exception_message):
+ asyncio.run(ValidPythonIdentifier.convert(self.context, name))
+
+ def test_duration_converter_for_valid(self):
+ """Duration returns the correct `datetime` for valid duration strings."""
+ test_values = (
+ # Simple duration strings
+ ('1Y', {"years": 1}),
+ ('1y', {"years": 1}),
+ ('1year', {"years": 1}),
+ ('1years', {"years": 1}),
+ ('1m', {"months": 1}),
+ ('1month', {"months": 1}),
+ ('1months', {"months": 1}),
+ ('1w', {"weeks": 1}),
+ ('1W', {"weeks": 1}),
+ ('1week', {"weeks": 1}),
+ ('1weeks', {"weeks": 1}),
+ ('1d', {"days": 1}),
+ ('1D', {"days": 1}),
+ ('1day', {"days": 1}),
+ ('1days', {"days": 1}),
+ ('1h', {"hours": 1}),
+ ('1H', {"hours": 1}),
+ ('1hour', {"hours": 1}),
+ ('1hours', {"hours": 1}),
+ ('1M', {"minutes": 1}),
+ ('1minute', {"minutes": 1}),
+ ('1minutes', {"minutes": 1}),
+ ('1s', {"seconds": 1}),
+ ('1S', {"seconds": 1}),
+ ('1second', {"seconds": 1}),
+ ('1seconds', {"seconds": 1}),
+
+ # Complex duration strings
+ (
+ '1y1m1w1d1H1M1S',
+ {
+ "years": 1,
+ "months": 1,
+ "weeks": 1,
+ "days": 1,
+ "hours": 1,
+ "minutes": 1,
+ "seconds": 1
+ }
+ ),
+ ('5y100S', {"years": 5, "seconds": 100}),
+ ('2w28H', {"weeks": 2, "hours": 28}),
+
+ # Duration strings with spaces
+ ('1 year 2 months', {"years": 1, "months": 2}),
+ ('1d 2H', {"days": 1, "hours": 2}),
+ ('1 week2 days', {"weeks": 1, "days": 2}),
+ )
+
+ converter = Duration()
+
+ for duration, duration_dict in test_values:
+ expected_datetime = self.fixed_utc_now + relativedelta(**duration_dict)
+
+ with patch('bot.converters.datetime') as mock_datetime:
+ mock_datetime.utcnow.return_value = self.fixed_utc_now
+
+ with self.subTest(duration=duration, duration_dict=duration_dict):
+ converted_datetime = asyncio.run(converter.convert(self.context, duration))
+ self.assertEqual(converted_datetime, expected_datetime)
+
+ def test_duration_converter_for_invalid(self):
+ """Duration raises the right exception for invalid duration strings."""
+ test_values = (
+ # Units in wrong order
+ ('1d1w'),
+ ('1s1y'),
+
+ # Duplicated units
+ ('1 year 2 years'),
+ ('1 M 10 minutes'),
+
+ # Unknown substrings
+ ('1MVes'),
+ ('1y3breads'),
+
+ # Missing amount
+ ('ym'),
+
+ # Incorrect whitespace
+ (" 1y"),
+ ("1S "),
+ ("1y 1m"),
+
+ # Garbage
+ ('Guido van Rossum'),
+ ('lemon lemon lemon lemon lemon lemon lemon'),
+ )
+
+ converter = Duration()
+
+ for invalid_duration in test_values:
+ with self.subTest(invalid_duration=invalid_duration):
+ exception_message = f'`{invalid_duration}` is not a valid duration string.'
+ with self.assertRaises(BadArgument, msg=exception_message):
+ asyncio.run(converter.convert(self.context, invalid_duration))
+
+ def test_isodatetime_converter_for_valid(self):
+ """ISODateTime converter returns correct datetime for valid datetime string."""
+ test_values = (
+ # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
+ ('2019-09-02T02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+
+ # `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM`
+ ('2019-09-02T03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+
+ # `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM`
+ ('2019-09-02T03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+
+ # `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH`
+ ('2019-09-02 03:03:05+01', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02T01:03:05-01', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+
+ # `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS`
+ ('2019-09-02T02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+ ('2019-09-02 02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)),
+
+ # `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM`
+ ('2019-11-12T09:15', datetime.datetime(2019, 11, 12, 9, 15)),
+ ('2019-11-12 09:15', datetime.datetime(2019, 11, 12, 9, 15)),
+
+ # `YYYY-mm-dd`
+ ('2019-04-01', datetime.datetime(2019, 4, 1)),
+
+ # `YYYY-mm`
+ ('2019-02-01', datetime.datetime(2019, 2, 1)),
+
+ # `YYYY`
+ ('2025', datetime.datetime(2025, 1, 1)),
+ )
+
+ converter = ISODateTime()
+
+ for datetime_string, expected_dt in test_values:
+ with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt):
+ converted_dt = asyncio.run(converter.convert(self.context, datetime_string))
+ self.assertIsNone(converted_dt.tzinfo)
+ self.assertEqual(converted_dt, expected_dt)
+
+ def test_isodatetime_converter_for_invalid(self):
+ """ISODateTime converter raises the correct exception for invalid datetime strings."""
+ test_values = (
+ # Make sure it doesn't interfere with the Duration converter
+ ('1Y'),
+ ('1d'),
+ ('1H'),
+
+ # Check if it fails when only providing the optional time part
+ ('10:10:10'),
+ ('10:00'),
+
+ # Invalid date format
+ ('19-01-01'),
+
+ # Other non-valid strings
+ ('fisk the tag master'),
+ )
+
+ converter = ISODateTime()
+ for datetime_string in test_values:
+ with self.subTest(datetime_string=datetime_string):
+ exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string"
+ with self.assertRaises(BadArgument, msg=exception_message):
+ asyncio.run(converter.convert(self.context, datetime_string))
diff --git a/tests/test_pagination.py b/tests/bot/test_pagination.py
index 11d6541ae..0a734b505 100644
--- a/tests/test_pagination.py
+++ b/tests/bot/test_pagination.py
@@ -1,28 +1,35 @@
from unittest import TestCase
-import pytest
-
from bot import pagination
class LinePaginatorTests(TestCase):
+ """Tests functionality of the `LinePaginator`."""
+
def setUp(self):
+ """Create a paginator for the test method."""
self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30)
def test_add_line_raises_on_too_long_lines(self):
+ """`add_line` should raise a `RuntimeError` for too long lines."""
message = f"Line exceeds maximum page size {self.paginator.max_size - 2}"
- with pytest.raises(RuntimeError, match=message):
+ with self.assertRaises(RuntimeError, msg=message):
self.paginator.add_line('x' * self.paginator.max_size)
def test_add_line_works_on_small_lines(self):
+ """`add_line` should allow small lines to be added."""
self.paginator.add_line('x' * (self.paginator.max_size - 3))
class ImagePaginatorTests(TestCase):
+ """Tests functionality of the `ImagePaginator`."""
+
def setUp(self):
+ """Create a paginator for the test method."""
self.paginator = pagination.ImagePaginator()
def test_add_image_appends_image(self):
+ """`add_image` appends the image to the image list."""
image = 'lemon'
self.paginator.add_image(image)
diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py
new file mode 100644
index 000000000..58ae2a81a
--- /dev/null
+++ b/tests/bot/test_utils.py
@@ -0,0 +1,52 @@
+import unittest
+
+from bot import utils
+
+
+class CaseInsensitiveDictTests(unittest.TestCase):
+ """Tests for the `CaseInsensitiveDict` container."""
+
+ def test_case_insensitive_key_access(self):
+ """Tests case insensitive key access and storage."""
+ instance = utils.CaseInsensitiveDict()
+
+ key = 'LEMON'
+ value = 'trees'
+
+ instance[key] = value
+ self.assertIn(key, instance)
+ self.assertEqual(instance.get(key), value)
+ self.assertEqual(instance.get(key.casefold()), value)
+ self.assertEqual(instance.pop(key.casefold()), value)
+ self.assertNotIn(key, instance)
+ self.assertNotIn(key.casefold(), instance)
+
+ instance.setdefault(key, value)
+ del instance[key]
+ self.assertNotIn(key, instance)
+
+ def test_initialization_from_kwargs(self):
+ """Tests creating the dictionary from keyword arguments."""
+ instance = utils.CaseInsensitiveDict({'FOO': 'bar'})
+ self.assertEqual(instance['foo'], 'bar')
+
+ def test_update_from_other_mapping(self):
+ """Tests updating the dictionary from another mapping."""
+ instance = utils.CaseInsensitiveDict()
+ instance.update({'FOO': 'bar'})
+ self.assertEqual(instance['foo'], 'bar')
+
+
+class ChunkTests(unittest.TestCase):
+ """Tests the `chunk` method."""
+
+ def test_empty_chunking(self):
+ """Tests chunking on an empty iterable."""
+ generator = utils.chunks(iterable=[], size=5)
+ self.assertEqual(list(generator), [])
+
+ def test_list_chunking(self):
+ """Tests chunking a non-empty list."""
+ iterable = [1, 2, 3, 4, 5]
+ generator = utils.chunks(iterable=iterable, size=2)
+ self.assertEqual(list(generator), [[1, 2], [3, 4], [5]])
diff --git a/tests/bot/utils/__init__.py b/tests/bot/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tests/bot/utils/__init__.py
diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py
new file mode 100644
index 000000000..19b758336
--- /dev/null
+++ b/tests/bot/utils/test_checks.py
@@ -0,0 +1,51 @@
+import unittest
+
+from bot.utils import checks
+from tests.helpers import MockContext, MockRole
+
+
+class ChecksTests(unittest.TestCase):
+ """Tests the check functions defined in `bot.checks`."""
+
+ def setUp(self):
+ self.ctx = MockContext()
+
+ def test_with_role_check_without_guild(self):
+ """`with_role_check` returns `False` if `Context.guild` is None."""
+ self.ctx.guild = None
+ self.assertFalse(checks.with_role_check(self.ctx))
+
+ def test_with_role_check_without_required_roles(self):
+ """`with_role_check` returns `False` if `Context.author` lacks the required role."""
+ self.ctx.author.roles = []
+ self.assertFalse(checks.with_role_check(self.ctx))
+
+ def test_with_role_check_with_guild_and_required_role(self):
+ """`with_role_check` returns `True` if `Context.author` has the required role."""
+ self.ctx.author.roles.append(MockRole(role_id=10))
+ self.assertTrue(checks.with_role_check(self.ctx, 10))
+
+ def test_without_role_check_without_guild(self):
+ """`without_role_check` should return `False` when `Context.guild` is None."""
+ self.ctx.guild = None
+ self.assertFalse(checks.without_role_check(self.ctx))
+
+ def test_without_role_check_returns_false_with_unwanted_role(self):
+ """`without_role_check` returns `False` if `Context.author` has unwanted role."""
+ role_id = 42
+ self.ctx.author.roles.append(MockRole(role_id=role_id))
+ self.assertFalse(checks.without_role_check(self.ctx, role_id))
+
+ def test_without_role_check_returns_true_without_unwanted_role(self):
+ """`without_role_check` returns `True` if `Context.author` does not have unwanted role."""
+ role_id = 42
+ self.ctx.author.roles.append(MockRole(role_id=role_id))
+ self.assertTrue(checks.without_role_check(self.ctx, role_id + 10))
+
+ def test_in_channel_check_for_correct_channel(self):
+ self.ctx.channel.id = 42
+ self.assertTrue(checks.in_channel_check(self.ctx, *[42]))
+
+ def test_in_channel_check_for_incorrect_channel(self):
+ self.ctx.channel.id = 42 + 10
+ self.assertFalse(checks.in_channel_check(self.ctx, *[42]))
diff --git a/tests/cogs/sync/test_roles.py b/tests/cogs/sync/test_roles.py
deleted file mode 100644
index c561ba447..000000000
--- a/tests/cogs/sync/test_roles.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from bot.cogs.sync.syncers import Role, get_roles_for_sync
-
-
-def test_get_roles_for_sync_empty_return_for_equal_roles():
- api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)}
-
- assert get_roles_for_sync(guild_roles, api_roles) == (set(), set(), set())
-
-
-def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff():
- api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)}
- guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)}
-
- assert get_roles_for_sync(guild_roles, api_roles) == (
- set(),
- guild_roles,
- set(),
- )
-
-
-def test_get_roles_only_returns_roles_that_require_update():
- api_roles = {
- Role(id=41, name='old name', colour=33, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
- guild_roles = {
- Role(id=41, name='new name', colour=35, permissions=0x8, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- assert get_roles_for_sync(guild_roles, api_roles) == (
- set(),
- {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)},
- set(),
- )
-
-
-def test_get_roles_returns_new_roles_in_first_tuple_element():
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=53, name='other role', colour=55, permissions=0, position=2)
- }
-
- assert get_roles_for_sync(guild_roles, api_roles) == (
- {Role(id=53, name='other role', colour=55, permissions=0, position=2)},
- set(),
- set(),
- )
-
-
-def test_get_roles_returns_roles_to_update_and_new_roles():
- api_roles = {
- Role(id=41, name='old name', colour=35, permissions=0x8, position=1),
- }
- guild_roles = {
- Role(id=41, name='new name', colour=40, permissions=0x16, position=2),
- Role(id=53, name='other role', colour=55, permissions=0, position=3)
- }
-
- assert get_roles_for_sync(guild_roles, api_roles) == (
- {Role(id=53, name='other role', colour=55, permissions=0, position=3)},
- {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)},
- set(),
- )
-
-
-def test_get_roles_returns_roles_to_delete():
- api_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- }
- guild_roles = {
- Role(id=41, name='name', colour=35, permissions=0x8, position=1),
- }
-
- assert get_roles_for_sync(guild_roles, api_roles) == (
- set(),
- set(),
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
-
-
-def test_get_roles_returns_roles_to_delete_update_and_new_roles():
- api_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=61, name='to delete', colour=99, permissions=0x9, position=2),
- Role(id=71, name='to update', colour=99, permissions=0x9, position=3),
- }
- guild_roles = {
- Role(id=41, name='not changed', colour=35, permissions=0x8, position=1),
- Role(id=81, name='to create', colour=99, permissions=0x9, position=4),
- Role(id=71, name='updated', colour=101, permissions=0x5, position=3),
- }
-
- assert get_roles_for_sync(guild_roles, api_roles) == (
- {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)},
- {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)},
- {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)},
- )
diff --git a/tests/cogs/sync/test_users.py b/tests/cogs/sync/test_users.py
deleted file mode 100644
index a863ae35b..000000000
--- a/tests/cogs/sync/test_users.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from bot.cogs.sync.syncers import User, get_users_for_sync
-
-
-def fake_user(**kwargs):
- kwargs.setdefault('id', 43)
- kwargs.setdefault('name', 'bob the test man')
- kwargs.setdefault('discriminator', 1337)
- kwargs.setdefault('avatar_hash', None)
- kwargs.setdefault('roles', (666,))
- kwargs.setdefault('in_guild', True)
- return User(**kwargs)
-
-
-def test_get_users_for_sync_returns_nothing_for_empty_params():
- assert get_users_for_sync({}, {}) == (set(), set())
-
-
-def test_get_users_for_sync_returns_nothing_for_equal_users():
- api_users = {43: fake_user()}
- guild_users = {43: fake_user()}
-
- assert get_users_for_sync(guild_users, api_users) == (set(), set())
-
-
-def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff():
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(name='new fancy name')}
-
- assert get_users_for_sync(guild_users, api_users) == (
- set(),
- {fake_user(name='new fancy name')}
- )
-
-
-def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild():
- api_users = {43: fake_user()}
- guild_users = {43: fake_user(), 63: fake_user(id=63)}
-
- assert get_users_for_sync(guild_users, api_users) == (
- {fake_user(id=63)},
- set()
- )
-
-
-def test_get_users_for_sync_updates_in_guild_field_on_user_leave():
- api_users = {43: fake_user(), 63: fake_user(id=63)}
- guild_users = {43: fake_user()}
-
- assert get_users_for_sync(guild_users, api_users) == (
- set(),
- {fake_user(id=63, in_guild=False)}
- )
-
-
-def test_get_users_for_sync_updates_and_creates_users_as_needed():
- api_users = {43: fake_user()}
- guild_users = {63: fake_user(id=63)}
-
- assert get_users_for_sync(guild_users, api_users) == (
- {fake_user(id=63)},
- {fake_user(in_guild=False)}
- )
-
-
-def test_get_users_for_sync_does_not_duplicate_update_users():
- api_users = {43: fake_user(in_guild=False)}
- guild_users = {}
-
- assert get_users_for_sync(guild_users, api_users) == (set(), set())
diff --git a/tests/cogs/test_antispam.py b/tests/cogs/test_antispam.py
deleted file mode 100644
index 67900b275..000000000
--- a/tests/cogs/test_antispam.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pytest
-
-from bot.cogs import antispam
-
-
-def test_default_antispam_config_is_valid():
- validation_errors = antispam.validate_config()
- assert not validation_errors
-
-
- ('config', 'expected'),
- (
- (
- {'invalid-rule': {}},
- {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."}
- ),
- (
- {'burst': {'interval': 10}},
- {'burst': "Key `max` is required but not set for rule `burst`"}
- ),
- (
- {'burst': {'max': 10}},
- {'burst': "Key `interval` is required but not set for rule `burst`"}
- )
- )
-)
-def test_invalid_antispam_config_returns_validation_errors(config, expected):
- validation_errors = antispam.validate_config(config)
- assert validation_errors == expected
diff --git a/tests/cogs/test_information.py b/tests/cogs/test_information.py
deleted file mode 100644
index 85b2d092e..000000000
--- a/tests/cogs/test_information.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import asyncio
-import logging
-import textwrap
-from datetime import datetime
-from unittest.mock import MagicMock, patch
-
-import pytest
-from discord import (
- CategoryChannel,
- Colour,
- TextChannel,
- VoiceChannel,
-)
-
-from bot.cogs import information
-from bot.constants import Emojis
-from bot.decorators import InChannelCheckFailure
-from tests.helpers import AsyncMock
-
-
-def cog(simple_bot):
- return information.Information(simple_bot)
-
-
-def role(name: str, id_: int):
- r = MagicMock()
- r.name = name
- r.id = id_
- r.mention = f'&{name}'
- return r
-
-
-def member(status: str):
- m = MagicMock()
- m.status = status
- return m
-
-
-def ctx(moderator_role, simple_ctx):
- simple_ctx.author.roles = [moderator_role]
- simple_ctx.guild.created_at = datetime(2001, 1, 1)
- simple_ctx.send = AsyncMock()
- return simple_ctx
-
-
-def test_roles_info_command(cog, ctx):
- everyone_role = MagicMock()
- everyone_role.name = '@everyone' # should be excluded in the output
- ctx.author.roles.append(everyone_role)
- ctx.guild.roles = ctx.author.roles
-
- cog.roles_info.can_run = AsyncMock()
- cog.roles_info.can_run.return_value = True
-
- coroutine = cog.roles_info.callback(cog, ctx)
-
- assert asyncio.run(coroutine) is None # no rval
- ctx.send.assert_called_once()
- _, kwargs = ctx.send.call_args
- embed = kwargs.pop('embed')
- assert embed.title == "Role information"
- assert embed.colour == Colour.blurple()
- assert embed.description == f"`{ctx.guild.roles[0].id}` - {ctx.guild.roles[0].mention}\n"
- assert embed.footer.text == "Total roles: 1"
-
-
-# There is no argument passed in here that we can use to test,
-# so the return value would change constantly.
-@patch('bot.cogs.information.time_since')
-def test_server_info_command(time_since_patch, cog, ctx, moderator_role):
- time_since_patch.return_value = '2 days ago'
-
- ctx.guild.created_at = datetime(2001, 1, 1)
- ctx.guild.features = ('lemons', 'apples')
- ctx.guild.region = 'The Moon'
- ctx.guild.roles = [moderator_role]
- ctx.guild.channels = [
- TextChannel(
- state={},
- guild=ctx.guild,
- data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'}
- ),
- CategoryChannel(
- state={},
- guild=ctx.guild,
- data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'}
- ),
- VoiceChannel(
- state={},
- guild=ctx.guild,
- data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'}
- )
- ]
- ctx.guild.members = [
- member('online'), member('online'),
- member('idle'),
- member('dnd'), member('dnd'), member('dnd'), member('dnd'),
- member('offline'), member('offline'), member('offline')
- ]
- ctx.guild.member_count = 1_234
- ctx.guild.icon_url = 'a-lemon.png'
-
- coroutine = cog.server_info.callback(cog, ctx)
- assert asyncio.run(coroutine) is None # no rval
-
- time_since_patch.assert_called_once_with(ctx.guild.created_at, precision='days')
- _, kwargs = ctx.send.call_args
- embed = kwargs.pop('embed')
- assert embed.colour == Colour.blurple()
- assert embed.description == textwrap.dedent(f"""
- **Server information**
- Created: {time_since_patch.return_value}
- Voice region: {ctx.guild.region}
- Features: {', '.join(ctx.guild.features)}
-
- **Counts**
- Members: {ctx.guild.member_count:,}
- Roles: {len(ctx.guild.roles)}
- Text: 1
- Voice: 1
- Channel categories: 1
-
- **Members**
- {Emojis.status_online} 2
- {Emojis.status_idle} 1
- {Emojis.status_dnd} 4
- {Emojis.status_offline} 3
- """)
- assert embed.thumbnail.url == 'a-lemon.png'
-
-
-def test_user_info_on_other_users_from_non_moderator(ctx, cog):
- ctx.author = MagicMock()
- ctx.author.__eq__.return_value = False
- ctx.author.roles = []
- coroutine = cog.user_info.callback(cog, ctx, user='scragly') # skip checks, pass args
-
- assert asyncio.run(coroutine) is None # no rval
- ctx.send.assert_called_once_with(
- "You may not use this command on users other than yourself."
- )
-
-
-def test_user_info_in_wrong_channel_from_non_moderator(ctx, cog):
- ctx.author = MagicMock()
- ctx.author.__eq__.return_value = False
- ctx.author.roles = []
-
- coroutine = cog.user_info.callback(cog, ctx)
- message = 'Sorry, but you may only use this command within <#267659945086812160>.'
- with pytest.raises(InChannelCheckFailure, match=message):
- assert asyncio.run(coroutine) is None # no rval
-
-
-def test_setup(simple_bot, caplog):
- information.setup(simple_bot)
- simple_bot.add_cog.assert_called_once()
- [record] = caplog.records
-
- assert record.message == "Cog loaded: Information"
- assert record.levelno == logging.INFO
diff --git a/tests/cogs/test_security.py b/tests/cogs/test_security.py
deleted file mode 100644
index 1efb460fe..000000000
--- a/tests/cogs/test_security.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import logging
-from unittest.mock import MagicMock
-
-import pytest
-from discord.ext.commands import NoPrivateMessage
-
-from bot.cogs import security
-
-
-def cog():
- bot = MagicMock()
- return security.Security(bot)
-
-
-def context():
- return MagicMock()
-
-
-def test_check_additions(cog):
- cog.bot.check.assert_any_call(cog.check_on_guild)
- cog.bot.check.assert_any_call(cog.check_not_bot)
-
-
-def test_check_not_bot_for_humans(cog, context):
- context.author.bot = False
- assert cog.check_not_bot(context)
-
-
-def test_check_not_bot_for_robots(cog, context):
- context.author.bot = True
- assert not cog.check_not_bot(context)
-
-
-def test_check_on_guild_outside_of_guild(cog, context):
- context.guild = None
-
- with pytest.raises(NoPrivateMessage, match="This command cannot be used in private messages."):
- cog.check_on_guild(context)
-
-
-def test_check_on_guild_on_guild(cog, context):
- context.guild = "lemon's lemonade stand"
- assert cog.check_on_guild(context)
-
-
-def test_security_cog_load(caplog):
- bot = MagicMock()
- security.setup(bot)
- bot.add_cog.assert_called_once()
- [record] = caplog.records
- assert record.message == "Cog loaded: Security"
- assert record.levelno == logging.INFO
diff --git a/tests/cogs/test_token_remover.py b/tests/cogs/test_token_remover.py
deleted file mode 100644
index 9d46b3a05..000000000
--- a/tests/cogs/test_token_remover.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import asyncio
-from unittest.mock import MagicMock
-
-import pytest
-from discord import Colour
-
-from bot.cogs.token_remover import (
- DELETION_MESSAGE_TEMPLATE,
- TokenRemover,
- setup as setup_cog,
-)
-from bot.constants import Channels, Colours, Event, Icons
-from tests.helpers import AsyncMock
-
-
-def token_remover():
- bot = MagicMock()
- bot.get_cog.return_value = MagicMock()
- bot.get_cog.return_value.send_log_message = AsyncMock()
- return TokenRemover(bot=bot)
-
-
-def message():
- message = MagicMock()
- message.author.__str__.return_value = 'lemon'
- message.author.bot = False
- message.author.avatar_url_as.return_value = 'picture-lemon.png'
- message.author.id = 42
- message.author.mention = '@lemon'
- message.channel.send = AsyncMock()
- message.channel.mention = '#lemonade-stand'
- message.content = ''
- message.delete = AsyncMock()
- message.id = 555
- return message
-
-
- ('content', 'expected'),
- (
- ('MTIz', True), # 123
- ('YWJj', False), # abc
- )
-)
-def test_is_valid_user_id(content: str, expected: bool):
- assert TokenRemover.is_valid_user_id(content) is expected
-
-
- ('content', 'expected'),
- (
- ('DN9r_A', True), # stolen from dapi, thanks to the author of the 'token' tag!
- ('MTIz', False), # 123
- )
-)
-def test_is_valid_timestamp(content: str, expected: bool):
- assert TokenRemover.is_valid_timestamp(content) is expected
-
-
-def test_mod_log_property(token_remover):
- token_remover.bot.get_cog.return_value = 'lemon'
- assert token_remover.mod_log == 'lemon'
- token_remover.bot.get_cog.assert_called_once_with('ModLog')
-
-
-def test_ignores_bot_messages(token_remover, message):
- message.author.bot = True
- coroutine = token_remover.on_message(message)
- assert asyncio.run(coroutine) is None
-
-
[email protected]('content', ('', 'lemon wins'))
-def test_ignores_messages_without_tokens(token_remover, message, content):
- message.content = content
- coroutine = token_remover.on_message(message)
- assert asyncio.run(coroutine) is None
-
-
[email protected]('content', ('foo.bar.baz', 'x.y.'))
-def test_ignores_invalid_tokens(token_remover, message, content):
- message.content = content
- coroutine = token_remover.on_message(message)
- assert asyncio.run(coroutine) is None
-
-
- 'content, censored_token',
- (
- ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'),
- )
-)
-def test_censors_valid_tokens(
- token_remover, message, content, censored_token, caplog
-):
- message.content = content
- coroutine = token_remover.on_message(message)
- assert asyncio.run(coroutine) is None # still no rval
-
- # asyncio logs some stuff about its reactor, discard it
- [_, record] = caplog.records
- assert record.message == (
- "Censored a seemingly valid token sent by lemon (`42`) in #lemonade-stand, "
- f"token was `{censored_token}`"
- )
-
- message.delete.assert_called_once_with()
- message.channel.send.assert_called_once_with(
- DELETION_MESSAGE_TEMPLATE.format(mention='@lemon')
- )
- token_remover.bot.get_cog.assert_called_with('ModLog')
- message.author.avatar_url_as.assert_called_once_with(static_format='png')
-
- mod_log = token_remover.bot.get_cog.return_value
- mod_log.ignore.assert_called_once_with(Event.message_delete, message.id)
- mod_log.send_log_message.assert_called_once_with(
- icon_url=Icons.token_removed,
- colour=Colour(Colours.soft_red),
- title="Token removed!",
- text=record.message,
- thumbnail='picture-lemon.png',
- channel_id=Channels.mod_alerts
- )
-
-
-def test_setup(caplog):
- bot = MagicMock()
- setup_cog(bot)
- [record] = caplog.records
-
- bot.add_cog.assert_called_once()
- assert record.message == "Cog loaded: TokenRemover"
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index d3de4484d..000000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from unittest.mock import MagicMock
-
-import pytest
-
-from bot.constants import Roles
-from tests.helpers import AsyncMock
-
-
-def moderator_role():
- mock = MagicMock()
- mock.id = Roles.moderator
- mock.name = 'Moderator'
- mock.mention = f'&{mock.name}'
- return mock
-
-
-def simple_bot():
- mock = MagicMock()
- mock._before_invoke = AsyncMock()
- mock._after_invoke = AsyncMock()
- mock.can_run = AsyncMock()
- mock.can_run.return_value = True
- return mock
-
-
-def simple_ctx(simple_bot):
- mock = MagicMock()
- mock.bot = simple_bot
- return mock
diff --git a/tests/helpers.py b/tests/helpers.py
index 2908294f7..8496ba031 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,23 +1,19 @@
+from __future__ import annotations
+
import asyncio
import functools
-from unittest.mock import MagicMock
-
-
-__all__ = ('AsyncMock', 'async_test')
-
+import inspect
+import unittest.mock
+from typing import Any, Iterable, Optional
-# TODO: Remove me on 3.8
-class AsyncMock(MagicMock):
- async def __call__(self, *args, **kwargs):
- return super(AsyncMock, self).__call__(*args, **kwargs)
+import discord
+from discord.ext.commands import Bot, Context
def async_test(wrapped):
"""
Run a test case via asyncio.
-
Example:
-
>>> @async_test
... async def lemon_wins():
... assert True
@@ -27,3 +23,369 @@ def async_test(wrapped):
def wrapper(*args, **kwargs):
return asyncio.run(wrapped(*args, **kwargs))
return wrapper
+
+
+class HashableMixin(discord.mixins.EqualityComparable):
+ """
+ Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin.
+
+ Note: discord.py`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions
+ for the relative small `id` integers we generally use in tests, this bit-shift is omitted.
+ """
+
+ def __hash__(self):
+ return self.id
+
+
+class ColourMixin:
+ """A mixin for Mocks that provides the aliasing of color->colour like discord.py does."""
+
+ @property
+ def color(self) -> discord.Colour:
+ return self.colour
+
+ @color.setter
+ def color(self, color: discord.Colour) -> None:
+ self.colour = color
+
+
+class CustomMockMixin:
+ """
+ Provides common functionality for our custom Mock types.
+
+ The cooperative `__init__` automatically creates `AsyncMock` attributes for every coroutine
+ function `inspect` detects in the `spec` instance we provide. In addition, this mixin takes care
+ of making sure child mocks are instantiated with the correct class. By default, the mock of the
+ children will be `unittest.mock.MagicMock`, but this can be overwritten by setting the attribute
+ `child_mock_type` on the custom mock inheriting from this mixin.
+ """
+
+ child_mock_type = unittest.mock.MagicMock
+
+ def __init__(self, spec: Any = None, **kwargs):
+ super().__init__(spec=spec, **kwargs)
+ if spec:
+ self._extract_coroutine_methods_from_spec_instance(spec)
+
+ def _get_child_mock(self, **kw):
+ """
+ Overwrite of the `_get_child_mock` method to stop the propagation of our custom mock classes.
+
+ Mock objects automatically create children when you access an attribute or call a method on them. By default,
+ the class of these children is the type of the parent itself. However, this would mean that the children created
+ for our custom mock types would also be instances of that custom mock type. This is not desirable, as attributes
+ of, e.g., a `Bot` object are not `Bot` objects themselves. The Python docs for `unittest.mock` hint that
+ overwriting this method is the best way to deal with that.
+
+ This override will look for an attribute called `child_mock_type` and use that as the type of the child mock.
+ """
+ klass = self.child_mock_type
+
+ if self._mock_sealed:
+ attribute = "." + kw["name"] if "name" in kw else "()"
+ mock_name = self._extract_mock_name() + attribute
+ raise AttributeError(mock_name)
+
+ return klass(**kw)
+
+ def _extract_coroutine_methods_from_spec_instance(self, source: Any) -> None:
+ """Automatically detect coroutine functions in `source` and set them as AsyncMock attributes."""
+ for name, _method in inspect.getmembers(source, inspect.iscoroutinefunction):
+ setattr(self, name, AsyncMock())
+
+
+# TODO: Remove me in Python 3.8
+class AsyncMock(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock async callables.
+
+ Python 3.8 will introduce an AsyncMock class in the standard library that will have some more
+ features; this stand-in only overwrites the `__call__` method to an async version.
+ """
+ async def __call__(self, *args, **kwargs):
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+# Create a guild instance to get a realistic Mock of `discord.Guild`
+guild_data = {
+ 'id': 1,
+ 'name': 'guild',
+ 'region': 'Europe',
+ 'verification_level': 2,
+ 'default_notications': 1,
+ 'afk_timeout': 100,
+ 'icon': "icon.png",
+ 'banner': 'banner.png',
+ 'mfa_level': 1,
+ 'splash': 'splash.png',
+ 'system_channel_id': 464033278631084042,
+ 'description': 'mocking is fun',
+ 'max_presences': 10_000,
+ 'max_members': 100_000,
+ 'preferred_locale': 'UTC',
+ 'owner_id': 1,
+ 'afk_channel_id': 464033278631084042,
+}
+guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock())
+
+
+class MockGuild(CustomMockMixin, unittest.mock.Mock, HashableMixin):
+ """
+ A `Mock` subclass to mock `discord.Guild` objects.
+
+ A MockGuild instance will follow the specifications of a `discord.Guild` instance. This means
+ that if the code you're testing tries to access an attribute or method that normally does not
+ exist for a `discord.Guild` object this will raise an `AttributeError`. This is to make sure our
+ tests fail if the code we're testing uses a `discord.Guild` object in the wrong way.
+
+ One restriction of that is that if the code tries to access an attribute that normally does not
+ exist for `discord.Guild` instance but was added dynamically, this will raise an exception with
+ the mocked object. To get around that, you can set the non-standard attribute explicitly for the
+ instance of `MockGuild`:
+
+ >>> guild = MockGuild()
+ >>> guild.attribute_that_normally_does_not_exist = unittest.mock.MagicMock()
+
+ In addition to attribute simulation, mocked guild object will pass an `isinstance` check against
+ `discord.Guild`:
+
+ >>> guild = MockGuild()
+ >>> isinstance(guild, discord.Guild)
+ True
+
+ For more info, see the `Mocking` section in `tests/README.md`.
+ """
+ def __init__(
+ self,
+ guild_id: int = 1,
+ roles: Optional[Iterable[MockRole]] = None,
+ members: Optional[Iterable[MockMember]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(spec=guild_instance, **kwargs)
+
+ self.id = guild_id
+
+ self.roles = [MockRole("@everyone", 1)]
+ if roles:
+ self.roles.extend(roles)
+
+ self.members = []
+ if members:
+ self.members.extend(members)
+
+
+# Create a Role instance to get a realistic Mock of `discord.Role`
+role_data = {'name': 'role', 'id': 1}
+role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data)
+
+
+class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
+ """
+ A Mock subclass to mock `discord.Role` objects.
+
+ Instances of this class will follow the specifications of `discord.Role` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None:
+ super().__init__(spec=role_instance, **kwargs)
+
+ self.name = name
+ self.id = role_id
+ self.position = position
+ self.mention = f'&{self.name}'
+
+ def __lt__(self, other):
+ """Simplified position-based comparisons similar to those of `discord.Role`."""
+ return self.position < other.position
+
+
+# Create a Member instance to get a realistic Mock of `discord.Member`
+member_data = {'user': 'lemon', 'roles': [1]}
+state_mock = unittest.mock.MagicMock()
+member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock)
+
+
+class MockMember(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin):
+ """
+ A Mock subclass to mock Member objects.
+
+ Instances of this class will follow the specifications of `discord.Member` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ def __init__(
+ self,
+ name: str = "member",
+ user_id: int = 1,
+ roles: Optional[Iterable[MockRole]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(spec=member_instance, **kwargs)
+
+ self.name = name
+ self.id = user_id
+
+ self.roles = [MockRole("@everyone", 1)]
+ if roles:
+ self.roles.extend(roles)
+
+ self.mention = f"@{self.name}"
+
+
+# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot`
+bot_instance = Bot(command_prefix=unittest.mock.MagicMock())
+
+
+class MockBot(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Bot objects.
+
+ Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances.
+ For more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=bot_instance, **kwargs)
+
+ # Our custom attributes and methods
+ self.http_session = unittest.mock.MagicMock()
+ self.api_client = unittest.mock.MagicMock()
+
+ # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and
+ # and should therefore be awaited. (The documentation calls it a coroutine as well, which
+ # is technically incorrect, since it's a regular def.)
+ self.wait_for = AsyncMock()
+
+
+# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel`
+channel_data = {
+ 'id': 1,
+ 'type': 'TextChannel',
+ 'name': 'channel',
+ 'parent_id': 1234567890,
+ 'topic': 'topic',
+ 'position': 1,
+ 'nsfw': False,
+ 'last_message_id': 1,
+}
+state = unittest.mock.MagicMock()
+guild = unittest.mock.MagicMock()
+channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data)
+
+
+class MockTextChannel(CustomMockMixin, unittest.mock.Mock, HashableMixin):
+ """
+ A MagicMock subclass to mock TextChannel objects.
+
+ Instances of this class will follow the specifications of `discord.TextChannel` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None:
+ super().__init__(spec=channel_instance, **kwargs)
+ self.id = channel_id
+ self.name = name
+ self.guild = kwargs.get('guild', MockGuild())
+ self.mention = f"#{self.name}"
+
+
+# Create a Message instance to get a realistic MagicMock of `discord.Message`
+message_data = {
+ 'id': 1,
+ 'webhook_id': 431341013479718912,
+ 'attachments': [],
+ 'embeds': [],
+ 'application': 'Python Discord',
+ 'activity': 'mocking',
+ 'channel': unittest.mock.MagicMock(),
+ 'edited_timestamp': '2019-10-14T15:33:48+00:00',
+ 'type': 'message',
+ 'pinned': False,
+ 'mention_everyone': False,
+ 'tts': None,
+ 'content': 'content',
+ 'nonce': None,
+}
+state = unittest.mock.MagicMock()
+channel = unittest.mock.MagicMock()
+message_instance = discord.Message(state=state, channel=channel, data=message_data)
+
+
+# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context`
+context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock())
+
+
+class MockContext(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Context objects.
+
+ Instances of this class will follow the specifications of `discord.ext.commands.Context`
+ instances. For more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=context_instance, **kwargs)
+ self.bot = kwargs.get('bot', MockBot())
+ self.guild = kwargs.get('guild', MockGuild())
+ self.author = kwargs.get('author', MockMember())
+ self.channel = kwargs.get('channel', MockTextChannel())
+ self.command = kwargs.get('command', unittest.mock.MagicMock())
+
+
+class MockMessage(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Message objects.
+
+ Instances of this class will follow the specifications of `discord.Message` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=message_instance, **kwargs)
+ self.author = kwargs.get('author', MockMember())
+ self.channel = kwargs.get('channel', MockTextChannel())
+
+
+emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'}
+emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data)
+
+
+class MockEmoji(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Emoji objects.
+
+ Instances of this class will follow the specifications of `discord.Emoji` instances. For more
+ information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=emoji_instance, **kwargs)
+ self.guild = kwargs.get('guild', MockGuild())
+
+ # Get all coroutine functions and set them as AsyncMock attributes
+ self._extract_coroutine_methods_from_spec_instance(emoji_instance)
+
+
+partial_emoji_instance = discord.PartialEmoji(animated=False, name='guido')
+
+
+class MockPartialEmoji(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock PartialEmoji objects.
+
+ Instances of this class will follow the specifications of `discord.PartialEmoji` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=partial_emoji_instance, **kwargs)
+
+
+reaction_instance = discord.Reaction(message=MockMessage(), data={'me': True}, emoji=MockEmoji())
+
+
+class MockReaction(CustomMockMixin, unittest.mock.MagicMock):
+ """
+ A MagicMock subclass to mock Reaction objects.
+
+ Instances of this class will follow the specifications of `discord.Reaction` instances. For
+ more information, see the `MockGuild` docstring.
+ """
+ def __init__(self, **kwargs) -> None:
+ super().__init__(spec=reaction_instance, **kwargs)
+ self.emoji = kwargs.get('emoji', MockEmoji())
+ self.message = kwargs.get('message', MockMessage())
diff --git a/tests/rules/test_attachments.py b/tests/rules/test_attachments.py
deleted file mode 100644
index 6f025b3cb..000000000
--- a/tests/rules/test_attachments.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import asyncio
-from dataclasses import dataclass
-from typing import Any, List
-
-import pytest
-
-from bot.rules import attachments
-
-
-# Using `MagicMock` sadly doesn't work for this usecase
-# since it's __eq__ compares the MagicMock's ID. We just
-# want to compare the actual attributes we set.
-@dataclass
-class FakeMessage:
- author: str
- attachments: List[Any]
-
-
-def msg(total_attachments: int):
- return FakeMessage(author='lemon', attachments=list(range(total_attachments)))
-
-
- 'messages',
- (
- (msg(0), msg(0), msg(0)),
- (msg(2), msg(2)),
- (msg(0),),
- )
-)
-def test_allows_messages_without_too_many_attachments(messages):
- last_message, *recent_messages = messages
- coro = attachments.apply(last_message, recent_messages, {'max': 5})
- assert asyncio.run(coro) is None
-
-
- ('messages', 'relevant_messages', 'total'),
- (
- ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10),
- ((msg(6),), [msg(6)], 6),
- ((msg(1),) * 6, [msg(1)] * 6, 6),
- )
-)
-def test_disallows_messages_with_too_many_attachments(messages, relevant_messages, total):
- last_message, *recent_messages = messages
- coro = attachments.apply(last_message, recent_messages, {'max': 5})
- assert asyncio.run(coro) == (
- f"sent {total} attachments in 5s",
- ('lemon',),
- relevant_messages
- )
diff --git a/tests/test_api.py b/tests/test_api.py
deleted file mode 100644
index ce69ef187..000000000
--- a/tests/test_api.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import logging
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from bot import api
-from tests.helpers import async_test
-
-
-def test_loop_is_not_running_by_default():
- assert not api.loop_is_running()
-
-
-@async_test
-async def test_loop_is_running_in_async_test():
- assert api.loop_is_running()
-
-
-def error_api_response():
- response = MagicMock()
- response.status = 999
- return response
-
-
-def api_log_handler():
- return api.APILoggingHandler(None)
-
-
-def debug_log_record():
- return logging.LogRecord(
- name='my.logger', level=logging.DEBUG,
- pathname='my/logger.py', lineno=666,
- msg="Lemon wins", args=(),
- exc_info=None
- )
-
-
-def test_response_code_error_default_initialization(error_api_response):
- error = api.ResponseCodeError(response=error_api_response)
- assert error.status is error_api_response.status
- assert not error.response_json
- assert not error.response_text
- assert error.response is error_api_response
-
-
-def test_response_code_error_default_representation(error_api_response):
- error = api.ResponseCodeError(response=error_api_response)
- assert str(error) == f"Status: {error_api_response.status} Response: "
-
-
-def test_response_code_error_representation_with_nonempty_response_json(error_api_response):
- error = api.ResponseCodeError(
- response=error_api_response,
- response_json={'hello': 'world'}
- )
- assert str(error) == f"Status: {error_api_response.status} Response: {{'hello': 'world'}}"
-
-
-def test_response_code_error_representation_with_nonempty_response_text(error_api_response):
- error = api.ResponseCodeError(
- response=error_api_response,
- response_text='Lemon will eat your soul'
- )
- assert str(error) == f"Status: {error_api_response.status} Response: Lemon will eat your soul"
-
-
-@patch('bot.api.APILoggingHandler.ship_off')
-def test_emit_appends_to_queue_with_stopped_event_loop(
- ship_off_patch, api_log_handler, debug_log_record
-):
- # This is a coroutine so returns something we should await,
- # but asyncio complains about that. To ease testing, we patch
- # `ship_off` to just return a regular value instead.
- ship_off_patch.return_value = 42
- api_log_handler.emit(debug_log_record)
-
- assert api_log_handler.queue == [42]
-
-
-def test_emit_ignores_less_than_debug(debug_log_record, api_log_handler):
- debug_log_record.levelno = logging.DEBUG - 5
- api_log_handler.emit(debug_log_record)
- assert not api_log_handler.queue
-
-
-def test_schedule_queued_tasks_for_empty_queue(api_log_handler, caplog):
- api_log_handler.schedule_queued_tasks()
- # Logs when tasks are scheduled
- assert not caplog.records
-
-
-@patch('asyncio.create_task')
-def test_schedule_queued_tasks_for_nonempty_queue(create_task_patch, api_log_handler, caplog):
- api_log_handler.queue = [555]
- api_log_handler.schedule_queued_tasks()
- assert not api_log_handler.queue
- create_task_patch.assert_called_once_with(555)
-
- [record] = caplog.records
- assert record.message == "Scheduled 1 pending logging tasks."
- assert record.levelno == logging.DEBUG
- assert record.name == 'bot.api'
- assert record.__dict__['via_handler']
diff --git a/tests/test_base.py b/tests/test_base.py
new file mode 100644
index 000000000..a16e2af8f
--- /dev/null
+++ b/tests/test_base.py
@@ -0,0 +1,91 @@
+import logging
+import unittest
+import unittest.mock
+
+
+from tests.base import LoggingTestCase, _CaptureLogHandler
+
+
+class LoggingTestCaseTests(unittest.TestCase):
+ """Tests for the LoggingTestCase."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.log = logging.getLogger(__name__)
+
+ def test_assert_not_logs_does_not_raise_with_no_logs(self):
+ """Test if LoggingTestCase.assertNotLogs does not raise when no logs were emitted."""
+ try:
+ with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
+ pass
+ except AssertionError:
+ self.fail("`self.assertNotLogs` raised an AssertionError when it should not!")
+
+ @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs")
+ def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs):
+ """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly."""
+ assertNotLogs.return_value = iter([None])
+ assertNotLogs.side_effect = AssertionError
+
+ message = "`self.assertNotLogs` raised an AssertionError when it should not!"
+ with self.assertRaises(AssertionError, msg=message):
+ self.test_assert_not_logs_does_not_raise_with_no_logs()
+
+ def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self):
+ """Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted."""
+ msg_regex = (
+ r"1 logs of DEBUG or higher were triggered on root:\n"
+ r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">'
+ )
+ with self.assertRaisesRegex(AssertionError, msg_regex):
+ with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
+ self.log.debug("Log!")
+
+ def test_assert_not_logs_reraises_unexpected_exception_in_managed_context(self):
+ """Test if LoggingTestCase.assertNotLogs reraises an unexpected exception."""
+ with self.assertRaises(ValueError, msg="test exception"):
+ with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
+ raise ValueError("test exception")
+
+ def test_assert_not_logs_restores_old_logging_settings(self):
+ """Test if LoggingTestCase.assertNotLogs reraises an unexpected exception."""
+ old_handlers = self.log.handlers[:]
+ old_level = self.log.level
+ old_propagate = self.log.propagate
+
+ with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG):
+ pass
+
+ self.assertEqual(self.log.handlers, old_handlers)
+ self.assertEqual(self.log.level, old_level)
+ self.assertEqual(self.log.propagate, old_propagate)
+
+ def test_logging_test_case_works_with_logger_instance(self):
+ """Test if the LoggingTestCase captures logging for provided logger."""
+ log = logging.getLogger("new_logger")
+ with self.assertRaises(AssertionError):
+ with LoggingTestCase.assertNotLogs(self, logger=log):
+ log.info("Hello, this should raise an AssertionError")
+
+ def test_logging_test_case_respects_alternative_logger(self):
+ """Test if LoggingTestCase only checks the provided logger."""
+ log_one = logging.getLogger("log one")
+ log_two = logging.getLogger("log two")
+ with LoggingTestCase.assertNotLogs(self, logger=log_one):
+ log_two.info("Hello, this should not raise an AssertionError")
+
+ def test_logging_test_case_respects_logging_level(self):
+ """Test if LoggingTestCase does not raise for a logging level lower than provided."""
+ with LoggingTestCase.assertNotLogs(self, level=logging.CRITICAL):
+ self.log.info("Hello, this should raise an AssertionError")
+
+ def test_capture_log_handler_default_initialization(self):
+ """Test if the _CaptureLogHandler is initialized properly."""
+ handler = _CaptureLogHandler()
+ self.assertFalse(handler.records)
+
+ def test_capture_log_handler_saves_record_on_emit(self):
+ """Test if the _CaptureLogHandler saves the log record when it's emitted."""
+ handler = _CaptureLogHandler()
+ handler.emit("Log message")
+ self.assertIn("Log message", handler.records)
diff --git a/tests/test_constants.py b/tests/test_constants.py
deleted file mode 100644
index e4a29d994..000000000
--- a/tests/test_constants.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import inspect
-
-import pytest
-
-from bot import constants
-
-
- 'section',
- (
- cls
- for (name, cls) in inspect.getmembers(constants)
- if hasattr(cls, 'section') and isinstance(cls, type)
- )
-)
-def test_section_configuration_matches_typespec(section):
- for (name, annotation) in section.__annotations__.items():
- value = getattr(section, name)
-
- if getattr(annotation, '_name', None) in ('Dict', 'List'):
- pytest.skip("Cannot validate containers yet")
-
- assert isinstance(value, annotation)
diff --git a/tests/test_converters.py b/tests/test_converters.py
deleted file mode 100644
index 35fc5d88e..000000000
--- a/tests/test_converters.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import asyncio
-import datetime
-from unittest.mock import MagicMock, patch
-
-import pytest
-from dateutil.relativedelta import relativedelta
-from discord.ext.commands import BadArgument
-
-from bot.converters import (
- Duration,
- TagContentConverter,
- TagNameConverter,
- ValidPythonIdentifier,
-)
-
-
- ('value', 'expected'),
- (
- ('hello', 'hello'),
- (' h ello ', 'h ello')
- )
-)
-def test_tag_content_converter_for_valid(value: str, expected: str):
- assert asyncio.run(TagContentConverter.convert(None, value)) == expected
-
-
- ('value', 'expected'),
- (
- ('', "Tag contents should not be empty, or filled with whitespace."),
- (' ', "Tag contents should not be empty, or filled with whitespace.")
- )
-)
-def test_tag_content_converter_for_invalid(value: str, expected: str):
- context = MagicMock()
- context.author = 'bob'
-
- with pytest.raises(BadArgument, match=expected):
- asyncio.run(TagContentConverter.convert(context, value))
-
-
- ('value', 'expected'),
- (
- ('tracebacks', 'tracebacks'),
- ('Tracebacks', 'tracebacks'),
- (' Tracebacks ', 'tracebacks'),
- )
-)
-def test_tag_name_converter_for_valid(value: str, expected: str):
- assert asyncio.run(TagNameConverter.convert(None, value)) == expected
-
-
- ('value', 'expected'),
- (
- ('👋', "Don't be ridiculous, you can't use that character!"),
- ('', "Tag names should not be empty, or filled with whitespace."),
- (' ', "Tag names should not be empty, or filled with whitespace."),
- ('42', "Tag names can't be numbers."),
- # Escape question mark as this is evaluated as regular expression.
- ('x' * 128, r"Are you insane\? That's way too long!"),
- )
-)
-def test_tag_name_converter_for_invalid(value: str, expected: str):
- context = MagicMock()
- context.author = 'bob'
-
- with pytest.raises(BadArgument, match=expected):
- asyncio.run(TagNameConverter.convert(context, value))
-
-
[email protected]('value', ('foo', 'lemon'))
-def test_valid_python_identifier_for_valid(value: str):
- assert asyncio.run(ValidPythonIdentifier.convert(None, value)) == value
-
-
[email protected]('value', ('nested.stuff', '#####'))
-def test_valid_python_identifier_for_invalid(value: str):
- with pytest.raises(BadArgument, match=f'`{value}` is not a valid Python identifier'):
- asyncio.run(ValidPythonIdentifier.convert(None, value))
-
-
-FIXED_UTC_NOW = datetime.datetime.fromisoformat('2019-01-01T00:00:00')
-
-
- params=(
- # Simple duration strings
- ('1Y', {"years": 1}),
- ('1y', {"years": 1}),
- ('1year', {"years": 1}),
- ('1years', {"years": 1}),
- ('1m', {"months": 1}),
- ('1month', {"months": 1}),
- ('1months', {"months": 1}),
- ('1w', {"weeks": 1}),
- ('1W', {"weeks": 1}),
- ('1week', {"weeks": 1}),
- ('1weeks', {"weeks": 1}),
- ('1d', {"days": 1}),
- ('1D', {"days": 1}),
- ('1day', {"days": 1}),
- ('1days', {"days": 1}),
- ('1h', {"hours": 1}),
- ('1H', {"hours": 1}),
- ('1hour', {"hours": 1}),
- ('1hours', {"hours": 1}),
- ('1M', {"minutes": 1}),
- ('1minute', {"minutes": 1}),
- ('1minutes', {"minutes": 1}),
- ('1s', {"seconds": 1}),
- ('1S', {"seconds": 1}),
- ('1second', {"seconds": 1}),
- ('1seconds', {"seconds": 1}),
-
- # Complex duration strings
- (
- '1y1m1w1d1H1M1S',
- {
- "years": 1,
- "months": 1,
- "weeks": 1,
- "days": 1,
- "hours": 1,
- "minutes": 1,
- "seconds": 1
- }
- ),
- ('5y100S', {"years": 5, "seconds": 100}),
- ('2w28H', {"weeks": 2, "hours": 28}),
-
- # Duration strings with spaces
- ('1 year 2 months', {"years": 1, "months": 2}),
- ('1d 2H', {"days": 1, "hours": 2}),
- ('1 week2 days', {"weeks": 1, "days": 2}),
- )
-)
-def create_future_datetime(request):
- """Yields duration string and target datetime.datetime object."""
- duration, duration_dict = request.param
- future_datetime = FIXED_UTC_NOW + relativedelta(**duration_dict)
- yield duration, future_datetime
-
-
-def test_duration_converter_for_valid(create_future_datetime: tuple):
- converter = Duration()
- duration, expected = create_future_datetime
- with patch('bot.converters.datetime') as mock_datetime:
- mock_datetime.utcnow.return_value = FIXED_UTC_NOW
- assert asyncio.run(converter.convert(None, duration)) == expected
-
-
- ('duration'),
- (
- # Units in wrong order
- ('1d1w'),
- ('1s1y'),
-
- # Duplicated units
- ('1 year 2 years'),
- ('1 M 10 minutes'),
-
- # Unknown substrings
- ('1MVes'),
- ('1y3breads'),
-
- # Missing amount
- ('ym'),
-
- # Incorrect whitespace
- (" 1y"),
- ("1S "),
- ("1y 1m"),
-
- # Garbage
- ('Guido van Rossum'),
- ('lemon lemon lemon lemon lemon lemon lemon'),
- )
-)
-def test_duration_converter_for_invalid(duration: str):
- converter = Duration()
- with pytest.raises(BadArgument, match=f'`{duration}` is not a valid duration string.'):
- asyncio.run(converter.convert(None, duration))
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
new file mode 100644
index 000000000..2b58634dd
--- /dev/null
+++ b/tests/test_helpers.py
@@ -0,0 +1,428 @@
+import asyncio
+import inspect
+import unittest
+import unittest.mock
+
+import discord
+
+from tests import helpers
+
+
+class DiscordMocksTests(unittest.TestCase):
+ """Tests for our specialized discord.py mocks."""
+
+ def test_mock_role_default_initialization(self):
+ """Test if the default initialization of MockRole results in the correct object."""
+ role = helpers.MockRole()
+
+ # The `spec` argument makes sure `isistance` checks with `discord.Role` pass
+ self.assertIsInstance(role, discord.Role)
+
+ self.assertEqual(role.name, "role")
+ self.assertEqual(role.id, 1)
+ self.assertEqual(role.position, 1)
+ self.assertEqual(role.mention, "&role")
+
+ def test_mock_role_alternative_arguments(self):
+ """Test if MockRole initializes with the arguments provided."""
+ role = helpers.MockRole(
+ name="Admins",
+ role_id=90210,
+ position=10,
+ )
+
+ self.assertEqual(role.name, "Admins")
+ self.assertEqual(role.id, 90210)
+ self.assertEqual(role.position, 10)
+ self.assertEqual(role.mention, "&Admins")
+
+ def test_mock_role_accepts_dynamic_arguments(self):
+ """Test if MockRole accepts and sets abitrary keyword arguments."""
+ role = helpers.MockRole(
+ guild="Dino Man",
+ hoist=True,
+ )
+
+ self.assertEqual(role.guild, "Dino Man")
+ self.assertTrue(role.hoist)
+
+ def test_mock_role_uses_position_for_less_than_greater_than(self):
+ """Test if `<` and `>` comparisons for MockRole are based on its position attribute."""
+ role_one = helpers.MockRole(position=1)
+ role_two = helpers.MockRole(position=2)
+ role_three = helpers.MockRole(position=3)
+
+ self.assertLess(role_one, role_two)
+ self.assertLess(role_one, role_three)
+ self.assertLess(role_two, role_three)
+ self.assertGreater(role_three, role_two)
+ self.assertGreater(role_three, role_one)
+ self.assertGreater(role_two, role_one)
+
+ def test_mock_member_default_initialization(self):
+ """Test if the default initialization of Mockmember results in the correct object."""
+ member = helpers.MockMember()
+
+ # The `spec` argument makes sure `isistance` checks with `discord.Member` pass
+ self.assertIsInstance(member, discord.Member)
+
+ self.assertEqual(member.name, "member")
+ self.assertEqual(member.id, 1)
+ self.assertListEqual(member.roles, [helpers.MockRole("@everyone", 1)])
+ self.assertEqual(member.mention, "@member")
+
+ def test_mock_member_alternative_arguments(self):
+ """Test if MockMember initializes with the arguments provided."""
+ core_developer = helpers.MockRole("Core Developer", 2)
+ member = helpers.MockMember(
+ name="Mark",
+ user_id=12345,
+ roles=[core_developer]
+ )
+
+ self.assertEqual(member.name, "Mark")
+ self.assertEqual(member.id, 12345)
+ self.assertListEqual(member.roles, [helpers.MockRole("@everyone", 1), core_developer])
+ self.assertEqual(member.mention, "@Mark")
+
+ def test_mock_member_accepts_dynamic_arguments(self):
+ """Test if MockMember accepts and sets abitrary keyword arguments."""
+ member = helpers.MockMember(
+ nick="Dino Man",
+ colour=discord.Colour.default(),
+ )
+
+ self.assertEqual(member.nick, "Dino Man")
+ self.assertEqual(member.colour, discord.Colour.default())
+
+ def test_mock_guild_default_initialization(self):
+ """Test if the default initialization of Mockguild results in the correct object."""
+ guild = helpers.MockGuild()
+
+ # The `spec` argument makes sure `isistance` checks with `discord.Guild` pass
+ self.assertIsInstance(guild, discord.Guild)
+
+ self.assertListEqual(guild.roles, [helpers.MockRole("@everyone", 1)])
+ self.assertListEqual(guild.members, [])
+
+ def test_mock_guild_alternative_arguments(self):
+ """Test if MockGuild initializes with the arguments provided."""
+ core_developer = helpers.MockRole("Core Developer", 2)
+ guild = helpers.MockGuild(
+ roles=[core_developer],
+ members=[helpers.MockMember(user_id=54321)],
+ )
+
+ self.assertListEqual(guild.roles, [helpers.MockRole("@everyone", 1), core_developer])
+ self.assertListEqual(guild.members, [helpers.MockMember(user_id=54321)])
+
+ def test_mock_guild_accepts_dynamic_arguments(self):
+ """Test if MockGuild accepts and sets abitrary keyword arguments."""
+ guild = helpers.MockGuild(
+ emojis=(":hyperjoseph:", ":pensive_ela:"),
+ premium_subscription_count=15,
+ )
+
+ self.assertTupleEqual(guild.emojis, (":hyperjoseph:", ":pensive_ela:"))
+ self.assertEqual(guild.premium_subscription_count, 15)
+
+ def test_mock_bot_default_initialization(self):
+ """Tests if MockBot initializes with the correct values."""
+ bot = helpers.MockBot()
+
+ # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Bot` pass
+ self.assertIsInstance(bot, discord.ext.commands.Bot)
+
+ def test_mock_context_default_initialization(self):
+ """Tests if MockContext initializes with the correct values."""
+ context = helpers.MockContext()
+
+ # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Context` pass
+ self.assertIsInstance(context, discord.ext.commands.Context)
+
+ self.assertIsInstance(context.bot, helpers.MockBot)
+ self.assertIsInstance(context.guild, helpers.MockGuild)
+ self.assertIsInstance(context.author, helpers.MockMember)
+
+ def test_mocks_allows_access_to_attributes_part_of_spec(self):
+ """Accessing attributes that are valid for the objects they mock should succeed."""
+ mocks = (
+ (helpers.MockGuild(), 'name'),
+ (helpers.MockRole(), 'hoist'),
+ (helpers.MockMember(), 'display_name'),
+ (helpers.MockBot(), 'user'),
+ (helpers.MockContext(), 'invoked_with'),
+ (helpers.MockTextChannel(), 'last_message'),
+ (helpers.MockMessage(), 'mention_everyone'),
+ )
+
+ for mock, valid_attribute in mocks:
+ with self.subTest(mock=mock):
+ try:
+ getattr(mock, valid_attribute)
+ except AttributeError:
+ msg = f"accessing valid attribute `{valid_attribute}` raised an AttributeError"
+ self.fail(msg)
+
+ @unittest.mock.patch(f'{__name__}.DiscordMocksTests.subTest')
+ @unittest.mock.patch(f'{__name__}.getattr')
+ def test_mock_allows_access_to_attributes_test(self, mock_getattr, mock_subtest):
+ """The valid attribute test should raise an AssertionError after an AttributeError."""
+ mock_getattr.side_effect = AttributeError
+
+ msg = "accessing valid attribute `name` raised an AttributeError"
+ with self.assertRaises(AssertionError, msg=msg):
+ self.test_mocks_allows_access_to_attributes_part_of_spec()
+
+ def test_mocks_rejects_access_to_attributes_not_part_of_spec(self):
+ """Accessing attributes that are invalid for the objects they mock should fail."""
+ mocks = (
+ helpers.MockGuild(),
+ helpers.MockRole(),
+ helpers.MockMember(),
+ helpers.MockBot(),
+ helpers.MockContext(),
+ helpers.MockTextChannel(),
+ helpers.MockMessage(),
+ )
+
+ for mock in mocks:
+ with self.subTest(mock=mock):
+ with self.assertRaises(AttributeError):
+ mock.the_cake_is_a_lie
+
+ def test_custom_mock_methods_are_valid_discord_object_methods(self):
+ """The `AsyncMock` attributes of the mocks should be valid for the class they're mocking."""
+ mocks = (
+ (helpers.MockGuild, helpers.guild_instance),
+ (helpers.MockRole, helpers.role_instance),
+ (helpers.MockMember, helpers.member_instance),
+ (helpers.MockBot, helpers.bot_instance),
+ (helpers.MockContext, helpers.context_instance),
+ (helpers.MockTextChannel, helpers.channel_instance),
+ (helpers.MockMessage, helpers.message_instance),
+ )
+
+ for mock_class, instance in mocks:
+ mock = mock_class()
+ async_methods = (
+ attr for attr in dir(mock) if isinstance(getattr(mock, attr), helpers.AsyncMock)
+ )
+
+ # spec_mock = unittest.mock.MagicMock(spec=instance)
+ for method in async_methods:
+ with self.subTest(mock_class=mock_class, method=method):
+ try:
+ getattr(instance, method)
+ except AttributeError:
+ msg = f"method {method} is not a method attribute of {instance.__class__}"
+ self.fail(msg)
+
+ @unittest.mock.patch(f'{__name__}.DiscordMocksTests.subTest')
+ def test_the_custom_mock_methods_test(self, subtest_mock):
+ """The custom method test should raise AssertionError for invalid methods."""
+ class FakeMockBot(helpers.CustomMockMixin, unittest.mock.MagicMock):
+ """Fake MockBot class with invalid attribute/method `release_the_walrus`."""
+
+ child_mock_type = unittest.mock.MagicMock
+
+ def __init__(self, **kwargs):
+ super().__init__(spec=helpers.bot_instance, **kwargs)
+
+ # Fake attribute
+ self.release_the_walrus = helpers.AsyncMock()
+
+ with unittest.mock.patch("tests.helpers.MockBot", new=FakeMockBot):
+ msg = "method release_the_walrus is not a valid method of <class 'discord.ext.commands.bot.Bot'>"
+ with self.assertRaises(AssertionError, msg=msg):
+ self.test_custom_mock_methods_are_valid_discord_object_methods()
+
+
+class MockObjectTests(unittest.TestCase):
+ """Tests the mock objects and mixins we've defined."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.hashable_mocks = (helpers.MockRole, helpers.MockMember, helpers.MockGuild)
+
+ def test_colour_mixin(self):
+ """Test if the ColourMixin adds aliasing of color -> colour for child classes."""
+ class MockHemlock(unittest.mock.MagicMock, helpers.ColourMixin):
+ pass
+
+ hemlock = MockHemlock()
+ hemlock.color = 1
+ self.assertEqual(hemlock.colour, 1)
+ self.assertEqual(hemlock.colour, hemlock.color)
+
+ def test_hashable_mixin_hash_returns_id(self):
+ """Test if the HashableMixing uses the id attribute for hashing."""
+ class MockScragly(unittest.mock.Mock, helpers.HashableMixin):
+ pass
+
+ scragly = MockScragly()
+ scragly.id = 10
+ self.assertEqual(hash(scragly), scragly.id)
+
+ def test_hashable_mixin_uses_id_for_equality_comparison(self):
+ """Test if the HashableMixing uses the id attribute for hashing."""
+ class MockScragly(unittest.mock.Mock, helpers.HashableMixin):
+ pass
+
+ scragly = MockScragly(spec=object)
+ scragly.id = 10
+ eevee = MockScragly(spec=object)
+ eevee.id = 10
+ python = MockScragly(spec=object)
+ python.id = 20
+
+ self.assertTrue(scragly == eevee)
+ self.assertFalse(scragly == python)
+
+ def test_hashable_mixin_uses_id_for_nonequality_comparison(self):
+ """Test if the HashableMixing uses the id attribute for hashing."""
+ class MockScragly(unittest.mock.Mock, helpers.HashableMixin):
+ pass
+
+ scragly = MockScragly(spec=object)
+ scragly.id = 10
+ eevee = MockScragly(spec=object)
+ eevee.id = 10
+ python = MockScragly(spec=object)
+ python.id = 20
+
+ self.assertTrue(scragly != python)
+ self.assertFalse(scragly != eevee)
+
+ def test_mock_class_with_hashable_mixin_uses_id_for_hashing(self):
+ """Test if the MagicMock subclasses that implement the HashableMixin use id for hash."""
+ for mock in self.hashable_mocks:
+ with self.subTest(mock_class=mock):
+ instance = helpers.MockRole(role_id=100)
+ self.assertEqual(hash(instance), instance.id)
+
+ def test_mock_class_with_hashable_mixin_uses_id_for_equality(self):
+ """Test if MagicMocks that implement the HashableMixin use id for equality comparisons."""
+ for mock_class in self.hashable_mocks:
+ with self.subTest(mock_class=mock_class):
+ instance_one = mock_class()
+ instance_two = mock_class()
+ instance_three = mock_class()
+
+ instance_one.id = 10
+ instance_two.id = 10
+ instance_three.id = 20
+
+ self.assertTrue(instance_one == instance_two)
+ self.assertFalse(instance_one == instance_three)
+
+ def test_mock_class_with_hashable_mixin_uses_id_for_nonequality(self):
+ """Test if MagicMocks that implement HashableMixin use id for nonequality comparisons."""
+ for mock_class in self.hashable_mocks:
+ with self.subTest(mock_class=mock_class):
+ instance_one = mock_class()
+ instance_two = mock_class()
+ instance_three = mock_class()
+
+ instance_one.id = 10
+ instance_two.id = 10
+ instance_three.id = 20
+
+ self.assertFalse(instance_one != instance_two)
+ self.assertTrue(instance_one != instance_three)
+
+ def test_custom_mock_mixin_accepts_mock_seal(self):
+ """The `CustomMockMixin` should support `unittest.mock.seal`."""
+ class MyMock(helpers.CustomMockMixin, unittest.mock.MagicMock):
+
+ child_mock_type = unittest.mock.MagicMock
+ pass
+
+ mock = MyMock()
+ unittest.mock.seal(mock)
+ with self.assertRaises(AttributeError, msg="MyMock.shirayuki"):
+ mock.shirayuki = "hello!"
+
+ def test_spec_propagation_of_mock_subclasses(self):
+ """Test if the `spec` does not propagate to attributes of the mock object."""
+ test_values = (
+ (helpers.MockGuild, "region"),
+ (helpers.MockRole, "mentionable"),
+ (helpers.MockMember, "display_name"),
+ (helpers.MockBot, "owner_id"),
+ (helpers.MockContext, "command_failed"),
+ (helpers.MockMessage, "mention_everyone"),
+ (helpers.MockEmoji, 'managed'),
+ (helpers.MockPartialEmoji, 'url'),
+ (helpers.MockReaction, 'me'),
+ )
+
+ for mock_type, valid_attribute in test_values:
+ with self.subTest(mock_type=mock_type, attribute=valid_attribute):
+ mock = mock_type()
+ self.assertTrue(isinstance(mock, mock_type))
+ attribute = getattr(mock, valid_attribute)
+ self.assertTrue(isinstance(attribute, mock_type.child_mock_type))
+
+ def test_extract_coroutine_methods_from_spec_instance_should_extract_all_and_only_coroutines(self):
+ """Test if all coroutine functions are extracted, but not regular methods or attributes."""
+ class CoroutineDonor:
+ def __init__(self):
+ self.some_attribute = 'alpha'
+
+ async def first_coroutine():
+ """This coroutine function should be extracted."""
+
+ async def second_coroutine():
+ """This coroutine function should be extracted."""
+
+ def regular_method():
+ """This regular function should not be extracted."""
+
+ class Receiver:
+ pass
+
+ donor = CoroutineDonor()
+ receiver = Receiver()
+
+ helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance(receiver, donor)
+
+ self.assertIsInstance(receiver.first_coroutine, helpers.AsyncMock)
+ self.assertIsInstance(receiver.second_coroutine, helpers.AsyncMock)
+ self.assertFalse(hasattr(receiver, 'regular_method'))
+ self.assertFalse(hasattr(receiver, 'some_attribute'))
+
+ @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
+ @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
+ def test_custom_mock_mixin_init_with_spec(self, extract_method_mock):
+ """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
+ spec = "pydis"
+
+ helpers.CustomMockMixin(spec=spec)
+
+ extract_method_mock.assert_called_once_with(spec)
+
+ @unittest.mock.patch("builtins.super", new=unittest.mock.MagicMock())
+ @unittest.mock.patch("tests.helpers.CustomMockMixin._extract_coroutine_methods_from_spec_instance")
+ def test_custom_mock_mixin_init_without_spec(self, extract_method_mock):
+ """Test if CustomMockMixin correctly passes on spec/kwargs and calls the extraction method."""
+ helpers.CustomMockMixin()
+
+ extract_method_mock.assert_not_called()
+
+ def test_async_mock_provides_coroutine_for_dunder_call(self):
+ """Test if AsyncMock objects have a coroutine for their __call__ method."""
+ async_mock = helpers.AsyncMock()
+ self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__))
+
+ coroutine = async_mock()
+ self.assertTrue(inspect.iscoroutine(coroutine))
+ self.assertIsNotNone(asyncio.run(coroutine))
+
+ def test_async_test_decorator_allows_synchronous_call_to_async_def(self):
+ """Test if the `async_test` decorator allows an `async def` to be called synchronously."""
+ @helpers.async_test
+ async def kosayoda():
+ return "return value"
+
+ self.assertEqual(kosayoda(), "return value")
diff --git a/tests/test_resources.py b/tests/test_resources.py
deleted file mode 100644
index 2b17aea64..000000000
--- a/tests/test_resources.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import json
-import mimetypes
-from pathlib import Path
-from urllib.parse import urlparse
-
-
-def test_stars_valid():
- """Validates that `bot/resources/stars.json` contains valid images."""
-
- path = Path('bot', 'resources', 'stars.json')
- content = path.read_text()
- data = json.loads(content)
-
- for url in data.values():
- assert urlparse(url).scheme == 'https'
-
- mimetype, _ = mimetypes.guess_type(url)
- assert mimetype in ('image/jpeg', 'image/png')
diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py
deleted file mode 100644
index 7121acebd..000000000
--- a/tests/utils/test_checks.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from unittest.mock import MagicMock
-
-import pytest
-
-from bot.utils import checks
-
-
-def context():
- return MagicMock()
-
-
-def test_with_role_check_without_guild(context):
- context.guild = None
-
- assert not checks.with_role_check(context)
-
-
-def test_with_role_check_with_guild_without_required_role(context):
- context.guild = True
- context.author.roles = []
-
- assert not checks.with_role_check(context)
-
-
-def test_with_role_check_with_guild_with_required_role(context):
- context.guild = True
- role = MagicMock()
- role.id = 42
- context.author.roles = (role,)
-
- assert checks.with_role_check(context, role.id)
-
-
-def test_without_role_check_without_guild(context):
- context.guild = None
-
- assert not checks.without_role_check(context)
-
-
-def test_without_role_check_with_unwanted_role(context):
- context.guild = True
- role = MagicMock()
- role.id = 42
- context.author.roles = (role,)
-
- assert not checks.without_role_check(context, role.id)
-
-
-def test_without_role_check_without_unwanted_role(context):
- context.guild = True
- role = MagicMock()
- role.id = 42
- context.author.roles = (role,)
-
- assert checks.without_role_check(context, role.id + 10)
-
-
-def test_in_channel_check_for_correct_channel(context):
- context.channel.id = 42
- assert checks.in_channel_check(context, context.channel.id)
-
-
-def test_in_channel_check_for_incorrect_channel(context):
- context.channel.id = 42
- assert not checks.in_channel_check(context, context.channel.id + 10)
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
new file mode 100644
index 000000000..4baa6395c
--- /dev/null
+++ b/tests/utils/test_time.py
@@ -0,0 +1,62 @@
+import asyncio
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+import pytest
+from dateutil.relativedelta import relativedelta
+
+from bot.utils import time
+from tests.helpers import AsyncMock
+
+
+ ('delta', 'precision', 'max_units', 'expected'),
+ (
+ (relativedelta(days=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'seconds', 2, '2 days and 2 hours'),
+ (relativedelta(days=2, hours=2), 'seconds', 1, '2 days'),
+ (relativedelta(days=2, hours=2), 'days', 2, '2 days'),
+
+ # Does not abort for unknown units, as the unit name is checked
+ # against the attribute of the relativedelta instance.
+ (relativedelta(days=2, hours=2), 'elephants', 2, '2 days and 2 hours'),
+
+ # Very high maximum units, but it only ever iterates over
+ # each value the relativedelta might have.
+ (relativedelta(days=2, hours=2), 'hours', 20, '2 days and 2 hours'),
+ )
+)
+def test_humanize_delta(
+ delta: relativedelta,
+ precision: str,
+ max_units: int,
+ expected: str
+):
+ assert time.humanize_delta(delta, precision, max_units) == expected
+
+
[email protected]('max_units', (-1, 0))
+def test_humanize_delta_raises_for_invalid_max_units(max_units: int):
+ with pytest.raises(ValueError, match='max_units must be positive'):
+ time.humanize_delta(relativedelta(days=2, hours=2), 'hours', max_units)
+
+
+ ('stamp', 'expected'),
+ (
+ ('Sun, 15 Sep 2019 12:00:00 GMT', datetime(2019, 9, 15, 12, 0, 0, tzinfo=timezone.utc)),
+ )
+)
+def test_parse_rfc1123(stamp: str, expected: str):
+ assert time.parse_rfc1123(stamp) == expected
+
+
+@patch('asyncio.sleep', new_callable=AsyncMock)
+def test_wait_until(sleep_patch):
+ start = datetime(2019, 1, 1, 0, 0)
+ then = datetime(2019, 1, 1, 0, 10)
+
+ # No return value
+ assert asyncio.run(time.wait_until(then, start)) is None
+
+ sleep_patch.assert_called_once_with(10 * 60)