aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar S. Co1 <[email protected]>2020-02-23 14:09:38 -0500
committerGravatar S. Co1 <[email protected]>2020-02-23 14:09:38 -0500
commitea83fb74ff8bd252c70cfca214e9567b364712ba (patch)
tree2b12c43088bec679fba8d349cbd27b9721ab6756
parentAdd missed signature reformat from review (diff)
parentAdd Sentdex server to whitelist (diff)
Merge branch 'master' into reminder-enhancements
-rw-r--r--Pipfile4
-rw-r--r--Pipfile.lock121
-rw-r--r--azure-pipelines.yml2
-rw-r--r--bot/__init__.py87
-rw-r--r--bot/__main__.py13
-rw-r--r--bot/api.py74
-rw-r--r--bot/bot.py2
-rw-r--r--bot/cogs/defcon.py4
-rw-r--r--bot/cogs/error_handler.py25
-rw-r--r--bot/cogs/information.py109
-rw-r--r--bot/cogs/moderation/management.py7
-rw-r--r--bot/cogs/moderation/scheduler.py18
-rw-r--r--bot/cogs/reminders.py56
-rw-r--r--bot/cogs/tags.py6
-rw-r--r--bot/cogs/verification.py48
-rw-r--r--bot/constants.py2
-rw-r--r--bot/pagination.py42
-rw-r--r--bot/rules/attachments.py2
-rw-r--r--bot/utils/time.py36
-rw-r--r--config-default.yml4
-rw-r--r--docker-compose.yml2
-rw-r--r--tests/bot/cogs/test_information.py14
-rw-r--r--tests/bot/rules/__init__.py76
-rw-r--r--tests/bot/rules/test_attachments.py97
-rw-r--r--tests/bot/rules/test_burst.py56
-rw-r--r--tests/bot/rules/test_burst_shared.py59
-rw-r--r--tests/bot/rules/test_chars.py66
-rw-r--r--tests/bot/rules/test_discord_emojis.py54
-rw-r--r--tests/bot/rules/test_duplicates.py66
-rw-r--r--tests/bot/rules/test_links.py84
-rw-r--r--tests/bot/rules/test_mentions.py90
-rw-r--r--tests/bot/rules/test_newlines.py105
-rw-r--r--tests/bot/rules/test_role_mentions.py57
-rw-r--r--tests/bot/test_api.py64
-rw-r--r--tox.ini6
35 files changed, 951 insertions, 607 deletions
diff --git a/Pipfile b/Pipfile
index 7fd3efae8..400e64c18 100644
--- a/Pipfile
+++ b/Pipfile
@@ -6,7 +6,6 @@ name = "pypi"
[packages]
discord-py = "~=1.3.1"
aiodns = "~=2.0"
-logmatic-python = "~=0.1"
aiohttp = "~=3.5"
sphinx = "~=2.2"
markdownify = "~=0.4"
@@ -19,11 +18,12 @@ deepdiff = "~=4.0"
requests = "~=2.22"
more_itertools = "~=7.2"
urllib3 = ">=1.24.2,<1.25"
+sentry-sdk = "~=0.14"
[dev-packages]
coverage = "~=4.5"
flake8 = "~=3.7"
-flake8-annotations = "~=1.1"
+flake8-annotations = "~=2.0"
flake8-bugbear = "~=19.8"
flake8-docstrings = "~=1.4"
flake8-import-order = "~=0.18"
diff --git a/Pipfile.lock b/Pipfile.lock
index bf8ff47e9..fa29bf995 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
- "sha256": "0a0354a8cbd25b19c61b68f928493a445e737dc6447c97f4c4b52fbf72d887ac"
+ "sha256": "c7706a61eb96c06d073898018ea2dbcf5bd3b15d007496e2d60120a65647f31e"
},
"pipfile-spec": 6,
"requires": {
@@ -18,11 +18,11 @@
"default": {
"aio-pika": {
"hashes": [
- "sha256:a5837277e53755078db3a9e8c45bbca605c8ba9ecba7a02d74a7a1779f444723",
- "sha256:fa32e33b4b7d0804dcf439ae6ff24d2f0a83d1ba280ee9f555e647d71d394ff5"
+ "sha256:4199122a450dffd8303b7857a9d82657bf1487fe329e489520833b40fbe92406",
+ "sha256:fe85c7456e5c060bce4eb9cffab5b2c4d3c563cb72177977b3556c54c8e3aeb6"
],
"index": "pypi",
- "version": "==6.4.1"
+ "version": "==6.5.2"
},
"aiodns": {
"hashes": [
@@ -52,10 +52,10 @@
},
"aiormq": {
"hashes": [
- "sha256:8c215a970133ab5ee7c478decac55b209af7731050f52d11439fe910fa0f9e9d",
- "sha256:9210f3389200aee7d8067f6435f4a9eff2d3a30b88beb5eaae406ccc11c0fc01"
+ "sha256:286e0b0772075580466e45f98f051b9728a9316b9c36f0c14c7bc1409be375b0",
+ "sha256:7ed7d6df6b57af7f8bce7d1ebcbdfc32b676192e46703e81e9e217316e56b5bd"
],
- "version": "==3.2.0"
+ "version": "==3.2.1"
},
"alabaster": {
"hashes": [
@@ -164,18 +164,18 @@
},
"fuzzywuzzy": {
"hashes": [
- "sha256:5ac7c0b3f4658d2743aa17da53a55598144edbc5bee3c6863840636e6926f254",
- "sha256:6f49de47db00e1c71d40ad16da42284ac357936fa9b66bea1df63fed07122d62"
+ "sha256:45016e92264780e58972dca1b3d939ac864b78437422beecebb3095f8efd00e8",
+ "sha256:928244b28db720d1e0ee7587acf660ea49d7e4c632569cad4f1cd7e68a5f0993"
],
"index": "pypi",
- "version": "==0.17.0"
+ "version": "==0.18.0"
},
"idna": {
"hashes": [
- "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
- "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
+ "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb",
+ "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"
],
- "version": "==2.8"
+ "version": "==2.9"
},
"imagesize": {
"hashes": [
@@ -191,13 +191,6 @@
],
"version": "==2.11.1"
},
- "logmatic-python": {
- "hashes": [
- "sha256:0c15ac9f5faa6a60059b28910db642c3dc7722948c3cc940923f8c9039604342"
- ],
- "index": "pypi",
- "version": "==0.1.7"
- },
"lxml": {
"hashes": [
"sha256:06d4e0bbb1d62e38ae6118406d7cdb4693a3fa34ee3762238bcb96c9e36a93cd",
@@ -388,12 +381,6 @@
"index": "pypi",
"version": "==2.8.1"
},
- "python-json-logger": {
- "hashes": [
- "sha256:b7a31162f2a01965a5efb94453ce69230ed208468b0bbc7fdfc56e6d8df2e281"
- ],
- "version": "==0.1.11"
- },
"pytz": {
"hashes": [
"sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d",
@@ -420,11 +407,19 @@
},
"requests": {
"hashes": [
- "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
- "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31"
+ "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee",
+ "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6"
+ ],
+ "index": "pypi",
+ "version": "==2.23.0"
+ },
+ "sentry-sdk": {
+ "hashes": [
+ "sha256:b06dd27391fd11fb32f84fe054e6a64736c469514a718a99fb5ce1dff95d6b28",
+ "sha256:e023da07cfbead3868e1e2ba994160517885a32dfd994fc455b118e37989479b"
],
"index": "pypi",
- "version": "==2.22.0"
+ "version": "==0.14.1"
},
"six": {
"hashes": [
@@ -449,11 +444,11 @@
},
"sphinx": {
"hashes": [
- "sha256:298537cb3234578b2d954ff18c5608468229e116a9757af3b831c2b2b4819159",
- "sha256:e6e766b74f85f37a5f3e0773a1e1be8db3fcb799deb58ca6d18b70b0b44542a5"
+ "sha256:525527074f2e0c2585f68f73c99b4dc257c34bbe308b27f5f8c7a6e20642742f",
+ "sha256:543d39db5f82d83a5c1aa0c10c88f2b6cff2da3e711aa849b2c627b4b403bbd9"
],
"index": "pypi",
- "version": "==2.3.1"
+ "version": "==2.4.2"
},
"sphinxcontrib-applehelp": {
"hashes": [
@@ -556,6 +551,13 @@
}
},
"develop": {
+ "appdirs": {
+ "hashes": [
+ "sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92",
+ "sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e"
+ ],
+ "version": "==1.4.3"
+ },
"aspy.yaml": {
"hashes": [
"sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc",
@@ -579,10 +581,10 @@
},
"cfgv": {
"hashes": [
- "sha256:edb387943b665bf9c434f717bf630fa78aecd53d5900d2e05da6ad6048553144",
- "sha256:fbd93c9ab0a523bf7daec408f3be2ed99a980e20b2d19b50fc184ca6b820d289"
+ "sha256:04b093b14ddf9fd4d17c53ebfd55582d27b76ed30050193c14e560770c5360eb",
+ "sha256:f22b426ed59cd2ab2b54ff96608d846c33dfb8766a67f0b4a6ce130ce244414f"
],
- "version": "==2.0.1"
+ "version": "==3.0.0"
},
"chardet": {
"hashes": [
@@ -636,6 +638,12 @@
"index": "pypi",
"version": "==4.5.4"
},
+ "distlib": {
+ "hashes": [
+ "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21"
+ ],
+ "version": "==0.3.0"
+ },
"dodgy": {
"hashes": [
"sha256:28323cbfc9352139fdd3d316fa17f325cc0e9ac74438cbba51d70f9b48f86c3a",
@@ -658,6 +666,13 @@
],
"version": "==0.3"
},
+ "filelock": {
+ "hashes": [
+ "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59",
+ "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"
+ ],
+ "version": "==3.0.12"
+ },
"flake8": {
"hashes": [
"sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb",
@@ -668,11 +683,11 @@
},
"flake8-annotations": {
"hashes": [
- "sha256:05b85538014c850a86dce7374bb6621c64481c24e35e8e90af1315f4d7a3dbaa",
- "sha256:43e5233a76fda002b91a54a7cc4510f099c4bfd6279502ec70164016250eebd1"
+ "sha256:19a6637a5da1bb7ea7948483ca9e2b9e15b213e687e7bf5ff8c1bfc91c185006",
+ "sha256:bb033b72cdd3a2b0a530bbdf2081f12fbea7d70baeaaebb5899723a45f424b8e"
],
"index": "pypi",
- "version": "==1.1.3"
+ "version": "==2.0.0"
},
"flake8-bugbear": {
"hashes": [
@@ -700,11 +715,11 @@
},
"flake8-string-format": {
"hashes": [
- "sha256:68ea72a1a5b75e7018cae44d14f32473c798cf73d75cbaed86c6a9a907b770b2",
- "sha256:774d56103d9242ed968897455ef49b7d6de272000cfa83de5814273a868832f1"
+ "sha256:65f3da786a1461ef77fca3780b314edb2853c377f2e35069723348c8917deaa2",
+ "sha256:812ff431f10576a74c89be4e85b8e075a705be39bc40c4b4278b5b13e2afa9af"
],
"index": "pypi",
- "version": "==0.2.3"
+ "version": "==0.3.0"
},
"flake8-tidy-imports": {
"hashes": [
@@ -730,10 +745,10 @@
},
"idna": {
"hashes": [
- "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
- "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
+ "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb",
+ "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"
],
- "version": "==2.8"
+ "version": "==2.9"
},
"importlib-metadata": {
"hashes": [
@@ -818,11 +833,11 @@
},
"requests": {
"hashes": [
- "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
- "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31"
+ "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee",
+ "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6"
],
"index": "pypi",
- "version": "==2.22.0"
+ "version": "==2.23.0"
},
"safety": {
"hashes": [
@@ -898,17 +913,17 @@
},
"virtualenv": {
"hashes": [
- "sha256:0d62c70883c0342d59c11d0ddac0d954d0431321a41ab20851facf2b222598f3",
- "sha256:55059a7a676e4e19498f1aad09b8313a38fcc0cdbe4fdddc0e9b06946d21b4bb"
+ "sha256:08f3623597ce73b85d6854fb26608a6f39ee9d055c81178dc6583803797f8994",
+ "sha256:de2cbdd5926c48d7b84e0300dea9e8f276f61d186e8e49223d71d91250fbaebd"
],
- "version": "==16.7.9"
+ "version": "==20.0.4"
},
"zipp": {
"hashes": [
- "sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28",
- "sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98"
+ "sha256:12248a63bbdf7548f89cb4c7cda4681e537031eda29c02ea29674bc6854460c2",
+ "sha256:7c0f8e91abc0dc07a5068f315c52cb30c66bfbc581e5b50704c8a2f6ebae794a"
],
- "version": "==2.1.0"
+ "version": "==3.0.0"
}
}
}
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 0400ac4d2..874364a6f 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -30,7 +30,7 @@ jobs:
- script: python -m flake8
displayName: 'Run linter'
- - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
+ - script: BOT_API_KEY=foo BOT_SENTRY_DSN=blah BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner
displayName: Run tests
- script: coverage report -m && coverage xml -o coverage.xml
diff --git a/bot/__init__.py b/bot/__init__.py
index 789ace5c0..f7a410706 100644
--- a/bot/__init__.py
+++ b/bot/__init__.py
@@ -4,11 +4,8 @@ import sys
from logging import Logger, StreamHandler, handlers
from pathlib import Path
-from logmatic import JsonFormatter
-
-
-logging.TRACE = 5
-logging.addLevelName(logging.TRACE, "TRACE")
+TRACE_LEVEL = logging.TRACE = 5
+logging.addLevelName(TRACE_LEVEL, "TRACE")
def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
@@ -20,75 +17,29 @@ def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:
logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
"""
- if self.isEnabledFor(logging.TRACE):
- self._log(logging.TRACE, msg, args, **kwargs)
+ if self.isEnabledFor(TRACE_LEVEL):
+ self._log(TRACE_LEVEL, msg, args, **kwargs)
Logger.trace = monkeypatch_trace
-# Set up logging
-logging_handlers = []
-
-# We can't import this yet, so we have to define it ourselves
-DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False
-
-LOG_DIR = Path("logs")
-LOG_DIR.mkdir(exist_ok=True)
-
-if DEBUG_MODE:
- logging_handlers.append(StreamHandler(stream=sys.stdout))
-
- json_handler = logging.FileHandler(filename=Path(LOG_DIR, "log.json"), mode="w")
- json_handler.formatter = JsonFormatter()
- logging_handlers.append(json_handler)
-else:
-
- logfile = Path(LOG_DIR, "bot.log")
- megabyte = 1048576
-
- filehandler = handlers.RotatingFileHandler(logfile, maxBytes=(megabyte*5), backupCount=7)
- logging_handlers.append(filehandler)
-
- json_handler = logging.StreamHandler(stream=sys.stdout)
- json_handler.formatter = JsonFormatter()
- logging_handlers.append(json_handler)
-
-
-logging.basicConfig(
- format="%(asctime)s Bot: | %(name)33s | %(levelname)8s | %(message)s",
- datefmt="%b %d %H:%M:%S",
- level=logging.TRACE if DEBUG_MODE else logging.INFO,
- handlers=logging_handlers
-)
-
-log = logging.getLogger(__name__)
-
-
-for key, value in logging.Logger.manager.loggerDict.items():
- # Force all existing loggers to the correct level and handlers
- # This happens long before we instantiate our loggers, so
- # those should still have the expected level
-
- if key == "bot":
- continue
-
- if not isinstance(value, logging.Logger):
- # There might be some logging.PlaceHolder objects in there
- continue
+DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local")
- if DEBUG_MODE:
- value.setLevel(logging.DEBUG)
- else:
- value.setLevel(logging.INFO)
+log_format = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s")
- for handler in value.handlers.copy():
- value.removeHandler(handler)
+stream_handler = StreamHandler(stream=sys.stdout)
+stream_handler.setFormatter(log_format)
- for handler in logging_handlers:
- value.addHandler(handler)
+log_file = Path("logs", "bot.log")
+log_file.parent.mkdir(exist_ok=True)
+file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7)
+file_handler.setFormatter(log_format)
+root_log = logging.getLogger()
+root_log.setLevel(TRACE_LEVEL if DEBUG_MODE else logging.INFO)
+root_log.addHandler(stream_handler)
+root_log.addHandler(file_handler)
-# Silence irrelevant loggers
-logging.getLogger("aio_pika").setLevel(logging.ERROR)
-logging.getLogger("discord").setLevel(logging.ERROR)
-logging.getLogger("websockets").setLevel(logging.ERROR)
+logging.getLogger("discord").setLevel(logging.WARNING)
+logging.getLogger("websockets").setLevel(logging.WARNING)
+logging.getLogger(__name__)
diff --git a/bot/__main__.py b/bot/__main__.py
index 84bc7094b..490163739 100644
--- a/bot/__main__.py
+++ b/bot/__main__.py
@@ -1,10 +1,23 @@
+import logging
+
import discord
+import sentry_sdk
from discord.ext.commands import when_mentioned_or
+from sentry_sdk.integrations.logging import LoggingIntegration
from bot import patches
from bot.bot import Bot
from bot.constants import Bot as BotConfig, DEBUG_MODE
+sentry_logging = LoggingIntegration(
+ level=logging.TRACE,
+ event_level=logging.WARNING
+)
+
+sentry_sdk.init(
+ dsn=BotConfig.sentry_dsn,
+ integrations=[sentry_logging]
+)
bot = Bot(
command_prefix=when_mentioned_or(BotConfig.prefix),
diff --git a/bot/api.py b/bot/api.py
index 56db99828..fb126b384 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -141,77 +141,3 @@ def loop_is_running() -> bool:
except RuntimeError:
return False
return True
-
-
-class APILoggingHandler(logging.StreamHandler):
- """Site API logging handler."""
-
- def __init__(self, client: APIClient):
- logging.StreamHandler.__init__(self)
- self.client = client
-
- # internal batch of shipoff tasks that must not be scheduled
- # on the event loop yet - scheduled when the event loop is ready.
- self.queue = []
-
- async def ship_off(self, payload: dict) -> None:
- """Ship log payload to the logging API."""
- try:
- await self.client.post('logs', json=payload)
- except ResponseCodeError as err:
- log.warning(
- "Cannot send logging record to the site, got code %d.",
- err.response.status,
- extra={'via_handler': True}
- )
- except Exception as err:
- log.warning(
- "Cannot send logging record to the site: %r",
- err,
- extra={'via_handler': True}
- )
-
- def emit(self, record: logging.LogRecord) -> None:
- """
- Determine if a log record should be shipped to the logging API.
-
- If the asyncio event loop is not yet running, log records will instead be put in a queue
- which will be consumed once the event loop is running.
-
- The following two conditions are set:
- 1. Do not log anything below DEBUG (only applies to the monkeypatched `TRACE` level)
- 2. Ignore log records originating from this logging handler itself to prevent infinite recursion
- """
- if (
- record.levelno >= logging.DEBUG
- and not record.__dict__.get('via_handler')
- ):
- payload = {
- 'application': 'bot',
- 'logger_name': record.name,
- 'level': record.levelname.lower(),
- 'module': record.module,
- 'line': record.lineno,
- 'message': self.format(record)
- }
-
- task = self.ship_off(payload)
- if not loop_is_running():
- self.queue.append(task)
- else:
- asyncio.create_task(task)
- self.schedule_queued_tasks()
-
- def schedule_queued_tasks(self) -> None:
- """Consume the queue and schedule the logging of each queued record."""
- for task in self.queue:
- asyncio.create_task(task)
-
- if self.queue:
- log.debug(
- "Scheduled %d pending logging tasks.",
- len(self.queue),
- extra={'via_handler': True}
- )
-
- self.queue.clear()
diff --git a/bot/bot.py b/bot/bot.py
index 8f808272f..cecee7b68 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -27,8 +27,6 @@ class Bot(commands.Bot):
self.http_session: Optional[aiohttp.ClientSession] = None
self.api_client = api.APIClient(loop=self.loop, connector=self.connector)
- log.addHandler(api.APILoggingHandler(self.api_client))
-
def add_cog(self, cog: commands.Cog) -> None:
"""Adds a "cog" to the bot and logs the operation."""
super().add_cog(cog)
diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py
index 3e7350fcc..a0d8fedd5 100644
--- a/bot/cogs/defcon.py
+++ b/bot/cogs/defcon.py
@@ -76,12 +76,12 @@ class Defcon(Cog):
if data["enabled"]:
self.enabled = True
self.days = timedelta(days=data["days"])
- log.warning(f"DEFCON enabled: {self.days.days} days")
+ log.info(f"DEFCON enabled: {self.days.days} days")
else:
self.enabled = False
self.days = timedelta(days=0)
- log.warning(f"DEFCON disabled")
+ log.info(f"DEFCON disabled")
await self.update_channel_topic()
diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py
index 52893b2ee..0abb7e521 100644
--- a/bot/cogs/error_handler.py
+++ b/bot/cogs/error_handler.py
@@ -15,6 +15,7 @@ from discord.ext.commands import (
UserInputError,
)
from discord.ext.commands import Cog, Context
+from sentry_sdk import push_scope
from bot.api import ResponseCodeError
from bot.bot import Bot
@@ -147,10 +148,26 @@ class ErrorHandler(Cog):
f"Sorry, an unexpected error occurred. Please let us know!\n\n"
f"```{e.__class__.__name__}: {e}```"
)
- log.error(
- f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}"
- )
- raise e
+
+ with push_scope() as scope:
+ scope.user = {
+ "id": ctx.author.id,
+ "username": str(ctx.author)
+ }
+
+ scope.set_tag("command", ctx.command.qualified_name)
+ scope.set_tag("message_id", ctx.message.id)
+ scope.set_tag("channel_id", ctx.channel.id)
+
+ scope.set_extra("full_message", ctx.message.content)
+
+ if ctx.guild is not None:
+ scope.set_extra(
+ "jump_to",
+ f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}"
+ )
+
+ log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e)
def setup(bot: Bot) -> None:
diff --git a/bot/cogs/information.py b/bot/cogs/information.py
index 125d7ce24..13c8aabaa 100644
--- a/bot/cogs/information.py
+++ b/bot/cogs/information.py
@@ -2,14 +2,12 @@ import colorsys
import logging
import pprint
import textwrap
-import typing
-from collections import defaultdict
-from typing import Any, Mapping, Optional
-
-import discord
-from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils
-from discord.ext import commands
-from discord.ext.commands import BucketType, Cog, Context, command, group
+from collections import Counter, defaultdict
+from string import Template
+from typing import Any, Mapping, Optional, Union
+
+from discord import Colour, Embed, Member, Message, Role, Status, utils
+from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group
from discord.utils import escape_markdown
from bot import constants
@@ -32,8 +30,7 @@ class Information(Cog):
async def roles_info(self, ctx: Context) -> None:
"""Returns a list of all roles and their corresponding IDs."""
# Sort the roles alphabetically and remove the @everyone role
- roles = sorted(ctx.guild.roles, key=lambda role: role.name)
- roles = [role for role in roles if role.name != "@everyone"]
+ roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name)
# Build a string
role_string = ""
@@ -46,20 +43,20 @@ class Information(Cog):
colour=Colour.blurple(),
description=role_string
)
-
embed.set_footer(text=f"Total roles: {len(roles)}")
await ctx.send(embed=embed)
@with_role(*constants.MODERATION_ROLES)
@command(name="role")
- async def role_info(self, ctx: Context, *roles: typing.Union[Role, str]) -> None:
+ async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None:
"""
Return information on a role or list of roles.
To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks.
"""
parsed_roles = []
+ failed_roles = []
for role_name in roles:
if isinstance(role_name, Role):
@@ -70,29 +67,29 @@ class Information(Cog):
role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles)
if not role:
- await ctx.send(f":x: Could not convert `{role_name}` to a role")
+ failed_roles.append(role_name)
continue
parsed_roles.append(role)
+ if failed_roles:
+ await ctx.send(
+ ":x: I could not convert the following role names to a role: \n- "
+ "\n- ".join(failed_roles)
+ )
+
for role in parsed_roles:
+ h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb())
+
embed = Embed(
title=f"{role.name} info",
colour=role.colour,
)
-
embed.add_field(name="ID", value=role.id, inline=True)
-
embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True)
-
- h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb())
-
embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True)
-
embed.add_field(name="Member count", value=len(role.members), inline=True)
-
embed.add_field(name="Position", value=role.position)
-
embed.add_field(name="Permission code", value=role.permissions.value, inline=True)
await ctx.send(embed=embed)
@@ -104,40 +101,23 @@ class Information(Cog):
features = ", ".join(ctx.guild.features)
region = ctx.guild.region
- # How many of each type of channel?
roles = len(ctx.guild.roles)
- channels = ctx.guild.channels
- text_channels = 0
- category_channels = 0
- voice_channels = 0
- for channel in channels:
- if type(channel) == TextChannel:
- text_channels += 1
- elif type(channel) == CategoryChannel:
- category_channels += 1
- elif type(channel) == VoiceChannel:
- voice_channels += 1
-
- # How many of each user status?
member_count = ctx.guild.member_count
- members = ctx.guild.members
- online = 0
- dnd = 0
- idle = 0
- offline = 0
- for member in members:
- if str(member.status) == "online":
- online += 1
- elif str(member.status) == "offline":
- offline += 1
- elif str(member.status) == "idle":
- idle += 1
- elif str(member.status) == "dnd":
- dnd += 1
- embed = Embed(
- colour=Colour.blurple(),
- description=textwrap.dedent(f"""
+ # How many of each type of channel?
+ channels = Counter(c.type for c in ctx.guild.channels)
+ channel_counts = "".join(sorted(f"{str(ch).title()} channels: {channels[ch]}\n" for ch in channels)).strip()
+
+ # How many of each user status?
+ statuses = Counter(member.status for member in ctx.guild.members)
+ embed = Embed(colour=Colour.blurple())
+
+ # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the
+ # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting
+ # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts
+ # after the dedent is made.
+ embed.description = Template(
+ textwrap.dedent(f"""
**Server information**
Created: {created}
Voice region: {region}
@@ -146,18 +126,15 @@ class Information(Cog):
**Counts**
Members: {member_count:,}
Roles: {roles}
- Text: {text_channels}
- Voice: {voice_channels}
- Channel categories: {category_channels}
+ $channel_counts
**Members**
- {constants.Emojis.status_online} {online}
- {constants.Emojis.status_idle} {idle}
- {constants.Emojis.status_dnd} {dnd}
- {constants.Emojis.status_offline} {offline}
+ {constants.Emojis.status_online} {statuses[Status.online]:,}
+ {constants.Emojis.status_idle} {statuses[Status.idle]:,}
+ {constants.Emojis.status_dnd} {statuses[Status.dnd]:,}
+ {constants.Emojis.status_offline} {statuses[Status.offline]:,}
""")
- )
-
+ ).substitute({"channel_counts": channel_counts})
embed.set_thumbnail(url=ctx.guild.icon_url)
await ctx.send(embed=embed)
@@ -169,7 +146,7 @@ class Information(Cog):
user = ctx.author
# Do a role check if this is being executed on someone other than the caller
- if user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES):
+ elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES):
await ctx.send("You may not use this command on users other than yourself.")
return
@@ -202,7 +179,7 @@ class Information(Cog):
name = f"{user.nick} ({name})"
joined = time_since(user.joined_at, precision="days")
- roles = ", ".join(role.mention for role in user.roles if role.name != "@everyone")
+ roles = ", ".join(role.mention for role in user.roles[1:])
description = [
textwrap.dedent(f"""
@@ -356,13 +333,13 @@ class Information(Cog):
@cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES)
@group(invoke_without_command=True)
@in_channel(constants.Channels.bot, bypass_roles=constants.STAFF_ROLES)
- async def raw(self, ctx: Context, *, message: discord.Message, json: bool = False) -> None:
+ async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None:
"""Shows information about the raw API response."""
# I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling
# doing this extra request is also much easier than trying to convert everything back into a dictionary again
raw_data = await ctx.bot.http.get_message(message.channel.id, message.id)
- paginator = commands.Paginator()
+ paginator = Paginator()
def add_content(title: str, content: str) -> None:
paginator.add_line(f'== {title} ==\n')
@@ -390,7 +367,7 @@ class Information(Cog):
await ctx.send(page)
@raw.command()
- async def json(self, ctx: Context, message: discord.Message) -> None:
+ async def json(self, ctx: Context, message: Message) -> None:
"""Shows information about the raw API response in a copy-pasteable Python format."""
await ctx.invoke(self.raw, message=message, json=True)
diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py
index 0636422d3..f2964cd78 100644
--- a/bot/cogs/moderation/management.py
+++ b/bot/cogs/moderation/management.py
@@ -130,8 +130,11 @@ class ModManagement(commands.Cog):
# Re-schedule infraction if the expiration has been updated
if 'expires_at' in request_data:
self.infractions_cog.cancel_task(new_infraction['id'])
- loop = asyncio.get_event_loop()
- self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction)
+
+ # If the infraction was not marked as permanent, schedule a new expiration task
+ if request_data['expires_at']:
+ loop = asyncio.get_event_loop()
+ self.infractions_cog.schedule_task(loop, new_infraction['id'], new_infraction)
log_text += f"""
Previous expiry: {old_infraction['expires_at'] or "Permanent"}
diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py
index e14c302cb..c0de0e4da 100644
--- a/bot/cogs/moderation/scheduler.py
+++ b/bot/cogs/moderation/scheduler.py
@@ -309,16 +309,23 @@ class InfractionScheduler(Scheduler):
guild = self.bot.get_guild(constants.Guild.id)
mod_role = guild.get_role(constants.Roles.moderator)
user_id = infraction["user"]
+ actor = infraction["actor"]
type_ = infraction["type"]
id_ = infraction["id"]
+ inserted_at = infraction["inserted_at"]
+ expiry = infraction["expires_at"]
log.info(f"Marking infraction #{id_} as inactive (expired).")
+ expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None
+ created = time.format_infraction_with_duration(inserted_at, expiry)
+
log_content = None
log_text = {
- "Member": str(user_id),
- "Actor": str(self.bot.user),
- "Reason": infraction["reason"]
+ "Member": f"<@{user_id}>",
+ "Actor": str(self.bot.get_user(actor) or actor),
+ "Reason": infraction["reason"],
+ "Created": created,
}
try:
@@ -384,14 +391,19 @@ class InfractionScheduler(Scheduler):
if send_log:
log_title = f"expiration failed" if "Failure" in log_text else "expired"
+ user = self.bot.get_user(user_id)
+ avatar = user.avatar_url_as(static_format="png") if user else None
+
log.trace(f"Sending deactivation mod log for infraction #{id_}.")
await self.mod_log.send_log_message(
icon_url=utils.INFRACTION_ICONS[type_][1],
colour=Colours.soft_green,
title=f"Infraction {log_title}: {type_}",
+ thumbnail=avatar,
text="\n".join(f"{k}: {v}" for k, v in log_text.items()),
footer=f"ID: {id_}",
content=log_content,
+
)
return log_text
diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py
index ff803baf8..42229123b 100644
--- a/bot/cogs/reminders.py
+++ b/bot/cogs/reminders.py
@@ -2,13 +2,13 @@ import asyncio
import logging
import random
import textwrap
+import typing as t
from datetime import datetime, timedelta
from operator import itemgetter
-from typing import Optional
+import discord
from dateutil.parser import isoparse
from dateutil.relativedelta import relativedelta
-from discord import Colour, Embed, Message
from discord.ext.commands import Cog, Context, group
from bot.bot import Bot
@@ -46,6 +46,10 @@ class Reminders(Scheduler, Cog):
loop = asyncio.get_event_loop()
for reminder in response:
+ is_valid, *_ = self.ensure_valid_reminder(reminder)
+ if not is_valid:
+ continue
+
remind_at = isoparse(reminder['expiration']).replace(tzinfo=None)
# If the reminder is already overdue ...
@@ -55,16 +59,31 @@ class Reminders(Scheduler, Cog):
else:
self.schedule_task(loop, reminder["id"], reminder)
+ def ensure_valid_reminder(self, reminder: dict) -> t.Tuple[bool, discord.User, discord.TextChannel]:
+ """Ensure reminder author and channel can be fetched otherwise delete the reminder."""
+ user = self.bot.get_user(reminder['author'])
+ channel = self.bot.get_channel(reminder['channel_id'])
+ is_valid = True
+ if not user or not channel:
+ is_valid = False
+ log.info(
+ f"Reminder {reminder['id']} invalid: "
+ f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}."
+ )
+ asyncio.create_task(self._delete_reminder(reminder['id']))
+
+ return is_valid, user, channel
+
@staticmethod
async def _send_confirmation(
ctx: Context,
on_success: str,
reminder_id: str,
- delivery_dt: Optional[datetime],
+ delivery_dt: t.Optional[datetime],
) -> None:
"""Send an embed confirming the reminder change was made successfully."""
- embed = Embed()
- embed.colour = Colour.green()
+ embed = discord.Embed()
+ embed.colour = discord.Colour.green()
embed.title = random.choice(POSITIVE_REPLIES)
embed.description = on_success
@@ -108,11 +127,12 @@ class Reminders(Scheduler, Cog):
async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None:
"""Send the reminder."""
- channel = self.bot.get_channel(reminder["channel_id"])
- user = self.bot.get_user(reminder["author"])
+ is_valid, user, channel = self.ensure_valid_reminder(reminder)
+ if not is_valid:
+ return
- embed = Embed()
- embed.colour = Colour.blurple()
+ embed = discord.Embed()
+ embed.colour = discord.Colour.blurple()
embed.set_author(
icon_url=Icons.remind_blurple,
name="It has arrived!"
@@ -124,7 +144,7 @@ class Reminders(Scheduler, Cog):
embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})"
if late:
- embed.colour = Colour.red()
+ embed.colour = discord.Colour.red()
embed.set_author(
icon_url=Icons.remind_red,
name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!"
@@ -142,20 +162,20 @@ class Reminders(Scheduler, Cog):
await ctx.invoke(self.new_reminder, expiration=expiration, content=content)
@remind_group.command(name="new", aliases=("add", "create"))
- async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> Optional[Message]:
+ async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> t.Optional[discord.Message]:
"""
Set yourself a simple reminder.
Expiration is parsed per: http://strftime.org/
"""
- embed = Embed()
+ embed = discord.Embed()
# If the user is not staff, we need to verify whether or not to make a reminder at all.
if without_role_check(ctx, *STAFF_ROLES):
# If they don't have permission to set a reminder in this channel
if ctx.channel.id not in WHITELISTED_CHANNELS:
- embed.colour = Colour.red()
+ embed.colour = discord.Colour.red()
embed.title = random.choice(NEGATIVE_REPLIES)
embed.description = "Sorry, you can't do that here!"
@@ -172,7 +192,7 @@ class Reminders(Scheduler, Cog):
# Let's limit this, so we don't get 10 000
# reminders from kip or something like that :P
if len(active_reminders) > MAXIMUM_REMINDERS:
- embed.colour = Colour.red()
+ embed.colour = discord.Colour.red()
embed.title = random.choice(NEGATIVE_REPLIES)
embed.description = "You have too many active reminders!"
@@ -205,7 +225,7 @@ class Reminders(Scheduler, Cog):
self.schedule_task(loop, reminder["id"], reminder)
@remind_group.command(name="list")
- async def list_reminders(self, ctx: Context) -> Optional[Message]:
+ async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]:
"""View a paginated embed of all reminders for your user."""
# Get all the user's reminders from the database.
data = await self.bot.api_client.get(
@@ -238,8 +258,8 @@ class Reminders(Scheduler, Cog):
lines.append(text)
- embed = Embed()
- embed.colour = Colour.blurple()
+ embed = discord.Embed()
+ embed.colour = discord.Colour.blurple()
embed.title = f"Reminders for {ctx.author}"
# Remind the user that they have no reminders :^)
@@ -248,7 +268,7 @@ class Reminders(Scheduler, Cog):
return await ctx.send(embed=embed)
# Construct the embed and paginate it.
- embed.colour = Colour.blurple()
+ embed.colour = discord.Colour.blurple()
await LinePaginator.paginate(
lines,
diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py
index 54a51921c..b6360dfae 100644
--- a/bot/cogs/tags.py
+++ b/bot/cogs/tags.py
@@ -116,8 +116,10 @@ class Tags(Cog):
if _command_on_cooldown(tag_name):
time_left = Cooldowns.tags - (time.time() - self.tag_cooldowns[tag_name]["time"])
- log.warning(f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. "
- f"Cooldown ends in {time_left:.1f} seconds.")
+ log.info(
+ f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. "
+ f"Cooldown ends in {time_left:.1f} seconds."
+ )
return
await self._get_tags()
diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py
index 988e0d49a..582237374 100644
--- a/bot/cogs/verification.py
+++ b/bot/cogs/verification.py
@@ -1,7 +1,8 @@
import logging
+from contextlib import suppress
from datetime import datetime
-from discord import Colour, Message, NotFound, Object
+from discord import Colour, Forbidden, Message, NotFound, Object
from discord.ext import tasks
from discord.ext.commands import Cog, Context, command
@@ -92,19 +93,21 @@ class Verification(Cog):
ping_everyone=Filter.ping_everyone,
)
- ctx = await self.bot.get_context(message) # type: Context
-
+ ctx: Context = await self.bot.get_context(message)
if ctx.command is not None and ctx.command.name == "accept":
- return # They used the accept command
+ return
- for role in ctx.author.roles:
- if role.id == Roles.verified:
- log.warning(f"{ctx.author} posted '{ctx.message.content}' "
- "in the verification channel, but is already verified.")
- return # They're already verified
+ if any(r.id == Roles.verified for r in ctx.author.roles):
+ log.info(
+ f"{ctx.author} posted '{ctx.message.content}' "
+ "in the verification channel, but is already verified."
+ )
+ return
- log.debug(f"{ctx.author} posted '{ctx.message.content}' in the verification "
- "channel. We are providing instructions how to verify.")
+ log.debug(
+ f"{ctx.author} posted '{ctx.message.content}' in the verification "
+ "channel. We are providing instructions how to verify."
+ )
await ctx.send(
f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, "
f"and gain access to the rest of the server.",
@@ -112,11 +115,8 @@ class Verification(Cog):
)
log.trace(f"Deleting the message posted by {ctx.author}")
-
- try:
+ with suppress(NotFound):
await ctx.message.delete()
- except NotFound:
- log.trace("No message found, it must have been deleted by another bot.")
@command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True)
@without_role(Roles.verified)
@@ -127,17 +127,13 @@ class Verification(Cog):
await ctx.author.add_roles(Object(Roles.verified), reason="Accepted the rules")
try:
await ctx.author.send(WELCOME_MESSAGE)
- except Exception:
- # Catch the exception, in case they have DMs off or something
- log.exception(f"Unable to send welcome message to user {ctx.author}.")
-
- log.trace(f"Deleting the message posted by {ctx.author}.")
-
- try:
- self.mod_log.ignore(Event.message_delete, ctx.message.id)
- await ctx.message.delete()
- except NotFound:
- log.trace("No message found, it must have been deleted by another bot.")
+ except Forbidden:
+ log.info(f"Sending welcome message failed for {ctx.author}.")
+ finally:
+ log.trace(f"Deleting accept message by {ctx.author}.")
+ with suppress(NotFound):
+ self.mod_log.ignore(Event.message_delete, ctx.message.id)
+ await ctx.message.delete()
@command(name='subscribe')
@in_channel(Channels.bot)
diff --git a/bot/constants.py b/bot/constants.py
index e9990307a..681d8da49 100644
--- a/bot/constants.py
+++ b/bot/constants.py
@@ -193,7 +193,7 @@ class Bot(metaclass=YAMLGetter):
prefix: str
token: str
-
+ sentry_dsn: str
class Filter(metaclass=YAMLGetter):
section = "filter"
diff --git a/bot/pagination.py b/bot/pagination.py
index e82763912..90c8f849c 100644
--- a/bot/pagination.py
+++ b/bot/pagination.py
@@ -1,8 +1,9 @@
import asyncio
import logging
-from typing import Iterable, List, Optional, Tuple
+import typing as t
+from contextlib import suppress
-from discord import Embed, Member, Message, Reaction
+import discord
from discord.abc import User
from discord.ext.commands import Context, Paginator
@@ -14,7 +15,7 @@ RIGHT_EMOJI = "\u27A1" # [:arrow_right:]
LAST_EMOJI = "\u23ED" # [:track_next:]
DELETE_EMOJI = constants.Emojis.trashcan # [:trashcan:]
-PAGINATION_EMOJI = [FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI]
+PAGINATION_EMOJI = (FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI)
log = logging.getLogger(__name__)
@@ -89,12 +90,12 @@ class LinePaginator(Paginator):
@classmethod
async def paginate(
cls,
- lines: Iterable[str],
+ lines: t.List[str],
ctx: Context,
- embed: Embed,
+ embed: discord.Embed,
prefix: str = "",
suffix: str = "",
- max_lines: Optional[int] = None,
+ max_lines: t.Optional[int] = None,
max_size: int = 500,
empty: bool = True,
restrict_to_user: User = None,
@@ -102,7 +103,7 @@ class LinePaginator(Paginator):
footer_text: str = None,
url: str = None,
exception_on_empty_embed: bool = False
- ) -> Optional[Message]:
+ ) -> t.Optional[discord.Message]:
"""
Use a paginator and set of reactions to provide pagination over a set of lines.
@@ -114,11 +115,11 @@ class LinePaginator(Paginator):
Pagination will also be removed automatically if no reaction is added for five minutes (300 seconds).
Example:
- >>> embed = Embed()
+ >>> embed = discord.Embed()
>>> embed.set_author(name="Some Operation", url=url, icon_url=icon)
- >>> await LinePaginator.paginate((line for line in lines), ctx, embed)
+ >>> await LinePaginator.paginate([line for line in lines], ctx, embed)
"""
- def event_check(reaction_: Reaction, user_: Member) -> bool:
+ def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool:
"""Make sure that this reaction is what we want to operate on."""
no_restrictions = (
# Pagination is not restricted
@@ -281,8 +282,9 @@ class LinePaginator(Paginator):
await message.edit(embed=embed)
- log.debug("Ending pagination and removing all reactions...")
- await message.clear_reactions()
+ log.debug("Ending pagination and clearing reactions.")
+ with suppress(discord.NotFound):
+ await message.clear_reactions()
class ImagePaginator(Paginator):
@@ -299,6 +301,7 @@ class ImagePaginator(Paginator):
self._current_page = [prefix]
self.images = []
self._pages = []
+ self._count = 0
def add_line(self, line: str = '', *, empty: bool = False) -> None:
"""Adds a line to each page."""
@@ -316,13 +319,13 @@ class ImagePaginator(Paginator):
@classmethod
async def paginate(
cls,
- pages: List[Tuple[str, str]],
- ctx: Context, embed: Embed,
+ pages: t.List[t.Tuple[str, str]],
+ ctx: Context, embed: discord.Embed,
prefix: str = "",
suffix: str = "",
timeout: int = 300,
exception_on_empty_embed: bool = False
- ) -> Optional[Message]:
+ ) -> t.Optional[discord.Message]:
"""
Use a paginator and set of reactions to provide pagination over a set of title/image pairs.
@@ -334,11 +337,11 @@ class ImagePaginator(Paginator):
Note: Pagination will be removed automatically if no reaction is added for five minutes (300 seconds).
Example:
- >>> embed = Embed()
+ >>> embed = discord.Embed()
>>> embed.set_author(name="Some Operation", url=url, icon_url=icon)
>>> await ImagePaginator.paginate(pages, ctx, embed)
"""
- def check_event(reaction_: Reaction, member: Member) -> bool:
+ def check_event(reaction_: discord.Reaction, member: discord.Member) -> bool:
"""Checks each reaction added, if it matches our conditions pass the wait_for."""
return all((
# Reaction is on the same message sent
@@ -445,5 +448,6 @@ class ImagePaginator(Paginator):
await message.edit(embed=embed)
- log.debug("Ending pagination and removing all reactions...")
- await message.clear_reactions()
+ log.debug("Ending pagination and clearing reactions.")
+ with suppress(discord.NotFound):
+ await message.clear_reactions()
diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py
index 00bb2a949..8903c385c 100644
--- a/bot/rules/attachments.py
+++ b/bot/rules/attachments.py
@@ -19,7 +19,7 @@ async def apply(
if total_recent_attachments > config['max']:
return (
- f"sent {total_recent_attachments} attachments in {config['max']}s",
+ f"sent {total_recent_attachments} attachments in {config['interval']}s",
(last_message.author,),
relevant_messages
)
diff --git a/bot/utils/time.py b/bot/utils/time.py
index 7416f36e0..77060143c 100644
--- a/bot/utils/time.py
+++ b/bot/utils/time.py
@@ -114,30 +114,40 @@ def format_infraction(timestamp: str) -> str:
def format_infraction_with_duration(
- expiry: Optional[str],
+ date_to: Optional[str],
date_from: Optional[datetime.datetime] = None,
- max_units: int = 2
+ max_units: int = 2,
+ absolute: bool = True
) -> Optional[str]:
"""
- Format an infraction timestamp to a more readable ISO 8601 format WITH the duration.
+ Return `date_to` formatted as a readable ISO-8601 with the humanized duration since `date_from`.
- Returns a human-readable version of the duration between datetime.utcnow() and an expiry.
- Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it.
- `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours).
- By default, max_units is 2.
+ `date_from` must be an ISO-8601 formatted timestamp. The duration is calculated as from
+ `date_from` until `date_to` with a precision of seconds. If `date_from` is unspecified, the
+ current time is used.
+
+ `max_units` specifies the maximum number of units of time to include in the duration. For
+ example, a value of 1 may include days but not hours.
+
+ If `absolute` is True, the absolute value of the duration delta is used. This prevents negative
+ values in the case that `date_to` is in the past relative to `date_from`.
"""
- if not expiry:
+ if not date_to:
return None
+ date_to_formatted = format_infraction(date_to)
+
date_from = date_from or datetime.datetime.utcnow()
- date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0)
+ date_to = dateutil.parser.isoparse(date_to).replace(tzinfo=None, microsecond=0)
- expiry_formatted = format_infraction(expiry)
+ delta = relativedelta(date_to, date_from)
+ if absolute:
+ delta = abs(delta)
- duration = humanize_delta(relativedelta(date_to, date_from), max_units=max_units)
- duration_formatted = f" ({duration})" if duration else ''
+ duration = humanize_delta(delta, max_units=max_units)
+ duration_formatted = f" ({duration})" if duration else ""
- return f"{expiry_formatted}{duration_formatted}"
+ return f"{date_to_formatted}{duration_formatted}"
def until_expiration(
diff --git a/config-default.yml b/config-default.yml
index 3de7c6ba4..379475907 100644
--- a/config-default.yml
+++ b/config-default.yml
@@ -1,6 +1,7 @@
bot:
prefix: "!"
token: !ENV "BOT_TOKEN"
+ sentry_dsn: !ENV "BOT_SENTRY_DSN"
cooldowns:
# Per channel, per tag.
@@ -218,6 +219,7 @@ filter:
- 438622377094414346 # Pyglet
- 524691714909274162 # Panda3D
- 336642139381301249 # discord.py
+ - 405403391410438165 # Sentdex
domain_blacklist:
- pornhub.com
@@ -304,7 +306,7 @@ urls:
paste_service: !JOIN [*SCHEMA, *PASTE, "/{key}"]
# Snekbox
- snekbox_eval_api: "https://snekbox.pythondiscord.com/eval"
+ snekbox_eval_api: "http://snekbox:8060/eval"
# Discord API URLs
discord_api: &DISCORD_API "https://discordapp.com/api/v7/"
diff --git a/docker-compose.yml b/docker-compose.yml
index 7281c7953..11deceae8 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -23,6 +23,7 @@ services:
- staff.web
ports:
- "127.0.0.1:8000:8000"
+ tty: true
depends_on:
- postgres
environment:
@@ -37,6 +38,7 @@ services:
volumes:
- ./logs:/bot/logs
- .:/bot:ro
+ tty: true
depends_on:
- web
environment:
diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py
index 4496a2ae0..deae7ebad 100644
--- a/tests/bot/cogs/test_information.py
+++ b/tests/bot/cogs/test_information.py
@@ -125,10 +125,10 @@ class InformationCogTests(unittest.TestCase):
)
],
members=[
- *(helpers.MockMember(status='online') for _ in range(2)),
- *(helpers.MockMember(status='idle') for _ in range(1)),
- *(helpers.MockMember(status='dnd') for _ in range(4)),
- *(helpers.MockMember(status='offline') for _ in range(3)),
+ *(helpers.MockMember(status=discord.Status.online) for _ in range(2)),
+ *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)),
+ *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)),
+ *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)),
],
member_count=1_234,
icon_url='a-lemon.jpg',
@@ -153,9 +153,9 @@ class InformationCogTests(unittest.TestCase):
**Counts**
Members: {self.ctx.guild.member_count:,}
Roles: {len(self.ctx.guild.roles)}
- Text: 1
- Voice: 1
- Channel categories: 1
+ Category channels: 1
+ Text channels: 1
+ Voice channels: 1
**Members**
{constants.Emojis.status_online} 2
diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py
index e69de29bb..36c986fe1 100644
--- a/tests/bot/rules/__init__.py
+++ b/tests/bot/rules/__init__.py
@@ -0,0 +1,76 @@
+import unittest
+from abc import ABCMeta, abstractmethod
+from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple
+
+from tests.helpers import MockMessage
+
+
+class DisallowedCase(NamedTuple):
+ """Encapsulation for test cases expected to fail."""
+ recent_messages: List[MockMessage]
+ culprits: Iterable[str]
+ n_violations: int
+
+
+class RuleTest(unittest.TestCase, metaclass=ABCMeta):
+ """
+ Abstract class for antispam rule test cases.
+
+ Tests for specific rules should inherit from `RuleTest` and implement
+ `relevant_messages` and `get_report`. Each instance should also set the
+ `apply` and `config` attributes as necessary.
+
+ The execution of test cases can then be delegated to the `run_allowed`
+ and `run_disallowed` methods.
+ """
+
+ apply: Callable # The tested rule's apply function
+ config: Dict[str, int]
+
+ async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to pass."""
+ for recent_messages in cases:
+ last_message = recent_messages[0]
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ config=self.config,
+ ):
+ self.assertIsNone(
+ await self.apply(last_message, recent_messages, self.config)
+ )
+
+ async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None:
+ """Run all `cases` against `self.apply` expecting them to fail."""
+ for case in cases:
+ recent_messages, culprits, n_violations = case
+ last_message = recent_messages[0]
+ relevant_messages = self.relevant_messages(case)
+ desired_output = (
+ self.get_report(case),
+ culprits,
+ relevant_messages,
+ )
+
+ with self.subTest(
+ last_message=last_message,
+ recent_messages=recent_messages,
+ relevant_messages=relevant_messages,
+ n_violations=n_violations,
+ config=self.config,
+ ):
+ self.assertTupleEqual(
+ await self.apply(last_message, recent_messages, self.config),
+ desired_output,
+ )
+
+ @abstractmethod
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ """Give expected relevant messages for `case`."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_report(self, case: DisallowedCase) -> str:
+ """Give expected error report for `case`."""
+ raise NotImplementedError
diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py
index d7187f315..e54b4b5b8 100644
--- a/tests/bot/rules/test_attachments.py
+++ b/tests/bot/rules/test_attachments.py
@@ -1,98 +1,71 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import attachments
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_attachments: int
-
-
-def msg(author: str, total_attachments: int) -> MockMessage:
+def make_msg(author: str, total_attachments: int) -> MockMessage:
"""Builds a message with `total_attachments` attachments."""
return MockMessage(author=author, attachments=list(range(total_attachments)))
-class AttachmentRuleTests(unittest.TestCase):
+class AttachmentRuleTests(RuleTest):
"""Tests applying the `attachments` antispam rule."""
def setUp(self):
- self.config = {"max": 5}
+ self.apply = attachments.apply
+ self.config = {"max": 5, "interval": 10}
@async_test
async def test_allows_messages_without_too_many_attachments(self):
"""Messages without too many attachments are allowed as-is."""
cases = (
- [msg("bob", 0), msg("bob", 0), msg("bob", 0)],
- [msg("bob", 2), msg("bob", 2)],
- [msg("bob", 2), msg("alice", 2), msg("bob", 2)],
+ [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)],
+ [make_msg("bob", 2), make_msg("bob", 2)],
+ [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await attachments.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_disallows_messages_with_too_many_attachments(self):
"""Messages with too many attachments trigger the rule."""
cases = (
- Case(
- [msg("bob", 4), msg("bob", 0), msg("bob", 6)],
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],
("bob",),
- 10
+ 10,
),
- Case(
- [msg("bob", 4), msg("alice", 6), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],
("bob",),
- 6
+ 6,
),
- Case(
- [msg("alice", 6)],
+ DisallowedCase(
+ [make_msg("alice", 6)],
("alice",),
- 6
+ 6,
),
- (
- [msg("alice", 1) for _ in range(6)],
+ DisallowedCase(
+ [make_msg("alice", 1) for _ in range(6)],
("alice",),
- 6
+ 6,
),
)
- for recent_messages, culprit, total_attachments in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if (
- msg.author == last_message.author
- and len(msg.attachments) > 0
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and len(msg.attachments) > 0
)
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- total_attachments=total_attachments,
- config=self.config
- ):
- desired_output = (
- f"sent {total_attachments} attachments in {self.config['max']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await attachments.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} attachments in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py
new file mode 100644
index 000000000..72f0be0c7
--- /dev/null
+++ b/tests/bot/rules/test_burst.py
@@ -0,0 +1,56 @@
+from typing import Iterable
+
+from bot.rules import burst
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with author set to `author`.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstRuleTests(RuleTest):
+ """Tests the `burst` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("bob"), make_msg("bob")],
+ [make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ ("bob",),
+ 3,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py
new file mode 100644
index 000000000..47367a5f8
--- /dev/null
+++ b/tests/bot/rules/test_burst_shared.py
@@ -0,0 +1,59 @@
+from typing import Iterable
+
+from bot.rules import burst_shared
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str) -> MockMessage:
+ """
+ Init a MockMessage instance with the passed arg.
+
+ This serves as a shorthand / alias to keep the test cases visually clean.
+ """
+ return MockMessage(author=author)
+
+
+class BurstSharedRuleTests(RuleTest):
+ """Tests the `burst_shared` antispam rule."""
+
+ def setUp(self):
+ self.apply = burst_shared.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """
+ Cases that do not violate the rule.
+
+ There really isn't more to test here than a single case.
+ """
+ cases = (
+ [make_msg("spongebob"), make_msg("patrick")],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the amount of messages exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("bob")],
+ {"bob"},
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")],
+ {"bob", "alice"},
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return case.recent_messages
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py
new file mode 100644
index 000000000..7cc36f49e
--- /dev/null
+++ b/tests/bot/rules/test_chars.py
@@ -0,0 +1,66 @@
+from typing import Iterable
+
+from bot.rules import chars
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, n_chars: int) -> MockMessage:
+ """Build a message with arbitrary content of `n_chars` length."""
+ return MockMessage(author=author, content="A" * n_chars)
+
+
+class CharsRuleTests(RuleTest):
+ """Tests the `chars` antispam rule."""
+
+ def setUp(self):
+ self.apply = chars.apply
+ self.config = {
+ "max": 20, # Max allowed sum of chars per user
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of chars within limit."""
+ cases = (
+ [make_msg("bob", 0)],
+ [make_msg("bob", 20)],
+ [make_msg("bob", 15), make_msg("alice", 15)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases where the total amount of chars exceeds the limit, triggering the rule."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 21)],
+ ("bob",),
+ 21,
+ ),
+ DisallowedCase(
+ [make_msg("bob", 15), make_msg("bob", 15)],
+ ("bob",),
+ 30,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)],
+ ("alice",),
+ 30,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} characters in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py
new file mode 100644
index 000000000..0239b0b00
--- /dev/null
+++ b/tests/bot/rules/test_discord_emojis.py
@@ -0,0 +1,54 @@
+from typing import Iterable
+
+from bot.rules import discord_emojis
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id>
+
+
+def make_msg(author: str, n_emojis: int) -> MockMessage:
+ """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis."""
+ return MockMessage(author=author, content=discord_emoji * n_emojis)
+
+
+class DiscordEmojisRuleTests(RuleTest):
+ """Tests for the `discord_emojis` antispam rule."""
+
+ def setUp(self):
+ self.apply = discord_emojis.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of discord emojis within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of discord emojis."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ return tuple(msg for msg in case.recent_messages if msg.author in case.culprits)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} emojis in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py
new file mode 100644
index 000000000..59e0fb6ef
--- /dev/null
+++ b/tests/bot/rules/test_duplicates.py
@@ -0,0 +1,66 @@
+from typing import Iterable
+
+from bot.rules import duplicates
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, content: str) -> MockMessage:
+ """Give a MockMessage instance with `author` and `content` attrs."""
+ return MockMessage(author=author, content=content)
+
+
+class DuplicatesRuleTests(RuleTest):
+ """Tests the `duplicates` antispam rule."""
+
+ def setUp(self):
+ self.apply = duplicates.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", "A"), make_msg("alice", "A")],
+ [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate
+ [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with too many duplicate messages from the same author."""
+ cases = (
+ DisallowedCase(
+ [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")],
+ ("alice",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 duplicate messages, but only 3 from bob
+ ),
+ DisallowedCase(
+ [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")],
+ ("bob",),
+ 3, # 4 message from bob, but only 3 duplicates
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if (
+ msg.author == last_message.author
+ and msg.content == last_message.content
+ )
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py
index 02a5d5501..3c3f90e5f 100644
--- a/tests/bot/rules/test_links.py
+++ b/tests/bot/rules/test_links.py
@@ -1,26 +1,21 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import links
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_links: int
-
-
-def msg(author: str, total_links: int) -> MockMessage:
+def make_msg(author: str, total_links: int) -> MockMessage:
"""Makes a message with `total_links` links."""
content = " ".join(["https://pydis.com"] * total_links)
return MockMessage(author=author, content=content)
-class LinksTests(unittest.TestCase):
+class LinksTests(RuleTest):
"""Tests applying the `links` rule."""
def setUp(self):
+ self.apply = links.apply
self.config = {
"max": 2,
"interval": 10
@@ -30,68 +25,45 @@ class LinksTests(unittest.TestCase):
async def test_links_within_limit(self):
"""Messages with an allowed amount of links."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await links.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_links_exceeding_limit(self):
"""Messages with a a higher than allowed amount of links."""
cases = (
- Case(
- [msg("bob", 1), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 1), make_msg("bob", 2)],
("bob",),
3
),
- Case(
- [msg("alice", 1), msg("alice", 1), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],
("alice",),
3
),
- Case(
- [msg("alice", 2), msg("bob", 3), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],
("alice",),
3
)
)
- for recent_messages, culprit, total_links in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_links=total_links,
- config=self.config
- ):
- desired_output = (
- f"sent {total_links} links in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await links.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} links in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py
index ad49ead32..ebcdabac6 100644
--- a/tests/bot/rules/test_mentions.py
+++ b/tests/bot/rules/test_mentions.py
@@ -1,95 +1,67 @@
-import unittest
-from typing import List, NamedTuple, Tuple
+from typing import Iterable
from bot.rules import mentions
+from tests.bot.rules import DisallowedCase, RuleTest
from tests.helpers import MockMessage, async_test
-class Case(NamedTuple):
- recent_messages: List[MockMessage]
- culprit: Tuple[str]
- total_mentions: int
-
-
-def msg(author: str, total_mentions: int) -> MockMessage:
+def make_msg(author: str, total_mentions: int) -> MockMessage:
"""Makes a message with `total_mentions` mentions."""
return MockMessage(author=author, mentions=list(range(total_mentions)))
-class TestMentions(unittest.TestCase):
+class TestMentions(RuleTest):
"""Tests applying the `mentions` antispam rule."""
def setUp(self):
+ self.apply = mentions.apply
self.config = {
"max": 2,
- "interval": 10
+ "interval": 10,
}
@async_test
async def test_mentions_within_limit(self):
"""Messages with an allowed amount of mentions."""
cases = (
- [msg("bob", 0)],
- [msg("bob", 2)],
- [msg("bob", 1), msg("bob", 1)],
- [msg("bob", 1), msg("alice", 2)]
+ [make_msg("bob", 0)],
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("bob", 1)],
+ [make_msg("bob", 1), make_msg("alice", 2)],
)
- for recent_messages in cases:
- last_message = recent_messages[0]
-
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- config=self.config
- ):
- self.assertIsNone(
- await mentions.apply(last_message, recent_messages, self.config)
- )
+ await self.run_allowed(cases)
@async_test
async def test_mentions_exceeding_limit(self):
"""Messages with a higher than allowed amount of mentions."""
cases = (
- Case(
- [msg("bob", 3)],
+ DisallowedCase(
+ [make_msg("bob", 3)],
("bob",),
- 3
+ 3,
),
- Case(
- [msg("alice", 2), msg("alice", 0), msg("alice", 1)],
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],
("alice",),
- 3
+ 3,
),
- Case(
- [msg("bob", 2), msg("alice", 3), msg("bob", 2)],
+ DisallowedCase(
+ [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],
("bob",),
- 4
+ 4,
)
)
- for recent_messages, culprit, total_mentions in cases:
- last_message = recent_messages[0]
- relevant_messages = tuple(
- msg
- for msg in recent_messages
- if msg.author == last_message.author
- )
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
- with self.subTest(
- last_message=last_message,
- recent_messages=recent_messages,
- relevant_messages=relevant_messages,
- culprit=culprit,
- total_mentions=total_mentions,
- cofig=self.config
- ):
- desired_output = (
- f"sent {total_mentions} mentions in {self.config['interval']}s",
- culprit,
- relevant_messages
- )
- self.assertTupleEqual(
- await mentions.apply(last_message, recent_messages, self.config),
- desired_output
- )
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} mentions in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py
new file mode 100644
index 000000000..d61c4609d
--- /dev/null
+++ b/tests/bot/rules/test_newlines.py
@@ -0,0 +1,105 @@
+from typing import Iterable, List
+
+from bot.rules import newlines
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, newline_groups: List[int]) -> MockMessage:
+ """Init a MockMessage instance with `author` and content configured by `newline_groups".
+
+ Configure content by passing a list of ints, where each int `n` will generate
+ a separate group of `n` newlines.
+
+ Example:
+ newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n"
+ """
+ content = " ".join("\n" * n for n in newline_groups)
+ return MockMessage(author=author, content=content)
+
+
+class TotalNewlinesRuleTests(RuleTest):
+ """Tests the `newlines` antispam rule against allowed cases and total newline count violations."""
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {
+ "max": 5, # Max sum of newlines in relevant messages
+ "max_consecutive": 3, # Max newlines in one group, in one message
+ "interval": 10,
+ }
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases which do not violate the rule."""
+ cases = (
+ [make_msg("alice", [])], # Single message with no newlines
+ [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages
+ [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author
+ [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_total(self):
+ """Cases which violate the rule by having too many newlines in total."""
+ cases = (
+ DisallowedCase( # Alice sends a total of 6 newlines (disallowed)
+ [make_msg("alice", [2, 2]), make_msg("alice", [2])],
+ ("alice",),
+ 6,
+ ),
+ DisallowedCase( # Here we test that only alice's newlines count in the sum
+ [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])],
+ ("alice",),
+ 7,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} newlines in {self.config['interval']}s"
+
+
+class GroupNewlinesRuleTests(RuleTest):
+ """
+ Tests the `newlines` antispam rule against max consecutive newline violations.
+
+ As these violations yield a different error report, they require a different
+ `get_report` implementation.
+ """
+
+ def setUp(self):
+ self.apply = newlines.apply
+ self.config = {"max": 5, "max_consecutive": 3, "interval": 10}
+
+ @async_test
+ async def test_disallows_messages_consecutive(self):
+ """Cases which violate the rule due to having too many consecutive newlines."""
+ cases = (
+ DisallowedCase( # Bob sends a group of newlines too large
+ [make_msg("bob", [4])],
+ ("bob",),
+ 4,
+ ),
+ DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed)
+ [make_msg("alice", [1]), make_msg("alice", [4])],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_author = case.recent_messages[0].author
+ return tuple(msg for msg in case.recent_messages if msg.author == last_author)
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s"
diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py
new file mode 100644
index 000000000..b339cccf7
--- /dev/null
+++ b/tests/bot/rules/test_role_mentions.py
@@ -0,0 +1,57 @@
+from typing import Iterable
+
+from bot.rules import role_mentions
+from tests.bot.rules import DisallowedCase, RuleTest
+from tests.helpers import MockMessage, async_test
+
+
+def make_msg(author: str, n_mentions: int) -> MockMessage:
+ """Build a MockMessage instance with `n_mentions` role mentions."""
+ return MockMessage(author=author, role_mentions=[None] * n_mentions)
+
+
+class RoleMentionsRuleTests(RuleTest):
+ """Tests for the `role_mentions` antispam rule."""
+
+ def setUp(self):
+ self.apply = role_mentions.apply
+ self.config = {"max": 2, "interval": 10}
+
+ @async_test
+ async def test_allows_messages_within_limit(self):
+ """Cases with a total amount of role mentions within limit."""
+ cases = (
+ [make_msg("bob", 2)],
+ [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)],
+ )
+
+ await self.run_allowed(cases)
+
+ @async_test
+ async def test_disallows_messages_beyond_limit(self):
+ """Cases with more than the allowed amount of role mentions."""
+ cases = (
+ DisallowedCase(
+ [make_msg("bob", 3)],
+ ("bob",),
+ 3,
+ ),
+ DisallowedCase(
+ [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)],
+ ("alice",),
+ 4,
+ ),
+ )
+
+ await self.run_disallowed(cases)
+
+ def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]:
+ last_message = case.recent_messages[0]
+ return tuple(
+ msg
+ for msg in case.recent_messages
+ if msg.author == last_message.author
+ )
+
+ def get_report(self, case: DisallowedCase) -> str:
+ return f"sent {case.n_violations} role mentions in {self.config['interval']}s"
diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py
index 5a88adc5c..bdfcc73e4 100644
--- a/tests/bot/test_api.py
+++ b/tests/bot/test_api.py
@@ -1,9 +1,7 @@
-import logging
import unittest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
from bot import api
-from tests.base import LoggingTestCase
from tests.helpers import async_test
@@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase):
self.assertEqual(error.response_text, "")
self.assertIs(error.response, self.error_api_response)
- def test_responde_code_error_string_representation_default_initialization(self):
+ def test_response_code_error_string_representation_default_initialization(self):
"""Test the string representation of `ResponseCodeError` initialized without text or json."""
error = api.ResponseCodeError(response=self.error_api_response)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ")
@@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase):
response_text=text_data
)
self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}")
-
-
-class LoggingHandlerTests(LoggingTestCase):
- """Tests the bot's API Log Handler."""
-
- @classmethod
- def setUpClass(cls):
- cls.debug_log_record = logging.LogRecord(
- name='my.logger', level=logging.DEBUG,
- pathname='my/logger.py', lineno=666,
- msg="Lemon wins", args=(),
- exc_info=None
- )
-
- cls.trace_log_record = logging.LogRecord(
- name='my.logger', level=logging.TRACE,
- pathname='my/logger.py', lineno=666,
- msg="This will not be logged", args=(),
- exc_info=None
- )
-
- def setUp(self):
- self.log_handler = api.APILoggingHandler(None)
-
- def test_emit_appends_to_queue_with_stopped_event_loop(self):
- """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running."""
- with patch("bot.api.APILoggingHandler.ship_off") as ship_off:
- # Patch `ship_off` to ease testing against the return value of this coroutine.
- ship_off.return_value = 42
- self.log_handler.emit(self.debug_log_record)
-
- self.assertListEqual(self.log_handler.queue, [42])
-
- def test_emit_ignores_less_than_debug(self):
- """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG."""
- self.log_handler.emit(self.trace_log_record)
- self.assertListEqual(self.log_handler.queue, [])
-
- def test_schedule_queued_tasks_for_empty_queue(self):
- """`APILoggingHandler` should not schedule anything when the queue is empty."""
- with self.assertNotLogs(level=logging.DEBUG):
- self.log_handler.schedule_queued_tasks()
-
- def test_schedule_queued_tasks_for_nonempty_queue(self):
- """`APILoggingHandler` should schedule logs when the queue is not empty."""
- log = logging.getLogger("bot.api")
-
- with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task:
- self.log_handler.queue = [555]
- self.log_handler.schedule_queued_tasks()
- self.assertListEqual(self.log_handler.queue, [])
- create_task.assert_called_once_with(555)
-
- [record] = logs.records
- self.assertEqual(record.message, "Scheduled 1 pending logging tasks.")
- self.assertEqual(record.levelno, logging.DEBUG)
- self.assertEqual(record.name, 'bot.api')
- self.assertIn('via_handler', record.__dict__)
diff --git a/tox.ini b/tox.ini
index d14819d57..b8293a3b6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,7 +3,7 @@ max-line-length=120
docstring-convention=all
import-order-style=pycharm
application_import_names=bot,tests
-exclude=.cache,.venv,constants.py
+exclude=.cache,.venv,.git,constants.py
ignore=
B311,W503,E226,S311,T000
# Missing Docstrings
@@ -15,5 +15,5 @@ ignore=
# Docstring Content
D400,D401,D402,D404,D405,D406,D407,D408,D409,D410,D411,D412,D413,D414,D416,D417
# Type Annotations
- TYP002,TYP003,TYP101,TYP102,TYP204,TYP206
-per-file-ignores=tests/*:D,TYP
+ ANN002,ANN003,ANN101,ANN102,ANN204,ANN206
+per-file-ignores=tests/*:D,ANN