diff options
author | 2020-02-28 09:40:55 -0800 | |
---|---|---|
committer | 2020-02-28 09:40:55 -0800 | |
commit | b86fb16643d0026e8681d2163bd6ce86b092c57a (patch) | |
tree | b37065e33c88b42f5ce73dde98044777728b65f4 | |
parent | Ignore NotFound errors inside continue_eval (diff) | |
parent | Merge pull request #757 from python-discord/feat/backend/b131/error-handling (diff) |
Merge remote-tracking branch 'origin/master' into eval-enhancements
67 files changed, 3084 insertions, 1440 deletions
@@ -6,7 +6,6 @@ name = "pypi" [packages] discord-py = "~=1.3.1" aiodns = "~=2.0" -logmatic-python = "~=0.1" aiohttp = "~=3.5" sphinx = "~=2.2" markdownify = "~=0.4" @@ -19,11 +18,12 @@ deepdiff = "~=4.0" requests = "~=2.22" more_itertools = "~=7.2" urllib3 = ">=1.24.2,<1.25" +sentry-sdk = "~=0.14" [dev-packages] coverage = "~=4.5" flake8 = "~=3.7" -flake8-annotations = "~=1.1" +flake8-annotations = "~=2.0" flake8-bugbear = "~=19.8" flake8-docstrings = "~=1.4" flake8-import-order = "~=0.18" diff --git a/Pipfile.lock b/Pipfile.lock index bf8ff47e9..fa29bf995 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "0a0354a8cbd25b19c61b68f928493a445e737dc6447c97f4c4b52fbf72d887ac" + "sha256": "c7706a61eb96c06d073898018ea2dbcf5bd3b15d007496e2d60120a65647f31e" }, "pipfile-spec": 6, "requires": { @@ -18,11 +18,11 @@ "default": { "aio-pika": { "hashes": [ - "sha256:a5837277e53755078db3a9e8c45bbca605c8ba9ecba7a02d74a7a1779f444723", - "sha256:fa32e33b4b7d0804dcf439ae6ff24d2f0a83d1ba280ee9f555e647d71d394ff5" + "sha256:4199122a450dffd8303b7857a9d82657bf1487fe329e489520833b40fbe92406", + "sha256:fe85c7456e5c060bce4eb9cffab5b2c4d3c563cb72177977b3556c54c8e3aeb6" ], "index": "pypi", - "version": "==6.4.1" + "version": "==6.5.2" }, "aiodns": { "hashes": [ @@ -52,10 +52,10 @@ }, "aiormq": { "hashes": [ - "sha256:8c215a970133ab5ee7c478decac55b209af7731050f52d11439fe910fa0f9e9d", - "sha256:9210f3389200aee7d8067f6435f4a9eff2d3a30b88beb5eaae406ccc11c0fc01" + "sha256:286e0b0772075580466e45f98f051b9728a9316b9c36f0c14c7bc1409be375b0", + "sha256:7ed7d6df6b57af7f8bce7d1ebcbdfc32b676192e46703e81e9e217316e56b5bd" ], - "version": "==3.2.0" + "version": "==3.2.1" }, "alabaster": { "hashes": [ @@ -164,18 +164,18 @@ }, "fuzzywuzzy": { "hashes": [ - "sha256:5ac7c0b3f4658d2743aa17da53a55598144edbc5bee3c6863840636e6926f254", - "sha256:6f49de47db00e1c71d40ad16da42284ac357936fa9b66bea1df63fed07122d62" + "sha256:45016e92264780e58972dca1b3d939ac864b78437422beecebb3095f8efd00e8", + "sha256:928244b28db720d1e0ee7587acf660ea49d7e4c632569cad4f1cd7e68a5f0993" ], "index": "pypi", - "version": "==0.17.0" + "version": "==0.18.0" }, "idna": { "hashes": [ - "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", - "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" + "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb", + "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa" ], - "version": "==2.8" + "version": "==2.9" }, "imagesize": { "hashes": [ @@ -191,13 +191,6 @@ ], "version": "==2.11.1" }, - "logmatic-python": { - "hashes": [ - "sha256:0c15ac9f5faa6a60059b28910db642c3dc7722948c3cc940923f8c9039604342" - ], - "index": "pypi", - "version": "==0.1.7" - }, "lxml": { "hashes": [ "sha256:06d4e0bbb1d62e38ae6118406d7cdb4693a3fa34ee3762238bcb96c9e36a93cd", @@ -388,12 +381,6 @@ "index": "pypi", "version": "==2.8.1" }, - "python-json-logger": { - "hashes": [ - "sha256:b7a31162f2a01965a5efb94453ce69230ed208468b0bbc7fdfc56e6d8df2e281" - ], - "version": "==0.1.11" - }, "pytz": { "hashes": [ "sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d", @@ -420,11 +407,19 @@ }, "requests": { "hashes": [ - "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", - "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" + "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee", + "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6" + ], + "index": "pypi", + "version": "==2.23.0" + }, + "sentry-sdk": { + "hashes": [ + "sha256:b06dd27391fd11fb32f84fe054e6a64736c469514a718a99fb5ce1dff95d6b28", + "sha256:e023da07cfbead3868e1e2ba994160517885a32dfd994fc455b118e37989479b" ], "index": "pypi", - "version": "==2.22.0" + "version": "==0.14.1" }, "six": { "hashes": [ @@ -449,11 +444,11 @@ }, "sphinx": { "hashes": [ - "sha256:298537cb3234578b2d954ff18c5608468229e116a9757af3b831c2b2b4819159", - "sha256:e6e766b74f85f37a5f3e0773a1e1be8db3fcb799deb58ca6d18b70b0b44542a5" + "sha256:525527074f2e0c2585f68f73c99b4dc257c34bbe308b27f5f8c7a6e20642742f", + "sha256:543d39db5f82d83a5c1aa0c10c88f2b6cff2da3e711aa849b2c627b4b403bbd9" ], "index": "pypi", - "version": "==2.3.1" + "version": "==2.4.2" }, "sphinxcontrib-applehelp": { "hashes": [ @@ -556,6 +551,13 @@ } }, "develop": { + "appdirs": { + "hashes": [ + "sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92", + "sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e" + ], + "version": "==1.4.3" + }, "aspy.yaml": { "hashes": [ "sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc", @@ -579,10 +581,10 @@ }, "cfgv": { "hashes": [ - "sha256:edb387943b665bf9c434f717bf630fa78aecd53d5900d2e05da6ad6048553144", - "sha256:fbd93c9ab0a523bf7daec408f3be2ed99a980e20b2d19b50fc184ca6b820d289" + "sha256:04b093b14ddf9fd4d17c53ebfd55582d27b76ed30050193c14e560770c5360eb", + "sha256:f22b426ed59cd2ab2b54ff96608d846c33dfb8766a67f0b4a6ce130ce244414f" ], - "version": "==2.0.1" + "version": "==3.0.0" }, "chardet": { "hashes": [ @@ -636,6 +638,12 @@ "index": "pypi", "version": "==4.5.4" }, + "distlib": { + "hashes": [ + "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21" + ], + "version": "==0.3.0" + }, "dodgy": { "hashes": [ "sha256:28323cbfc9352139fdd3d316fa17f325cc0e9ac74438cbba51d70f9b48f86c3a", @@ -658,6 +666,13 @@ ], "version": "==0.3" }, + "filelock": { + "hashes": [ + "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59", + "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836" + ], + "version": "==3.0.12" + }, "flake8": { "hashes": [ "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", @@ -668,11 +683,11 @@ }, "flake8-annotations": { "hashes": [ - "sha256:05b85538014c850a86dce7374bb6621c64481c24e35e8e90af1315f4d7a3dbaa", - "sha256:43e5233a76fda002b91a54a7cc4510f099c4bfd6279502ec70164016250eebd1" + "sha256:19a6637a5da1bb7ea7948483ca9e2b9e15b213e687e7bf5ff8c1bfc91c185006", + "sha256:bb033b72cdd3a2b0a530bbdf2081f12fbea7d70baeaaebb5899723a45f424b8e" ], "index": "pypi", - "version": "==1.1.3" + "version": "==2.0.0" }, "flake8-bugbear": { "hashes": [ @@ -700,11 +715,11 @@ }, "flake8-string-format": { "hashes": [ - "sha256:68ea72a1a5b75e7018cae44d14f32473c798cf73d75cbaed86c6a9a907b770b2", - "sha256:774d56103d9242ed968897455ef49b7d6de272000cfa83de5814273a868832f1" + "sha256:65f3da786a1461ef77fca3780b314edb2853c377f2e35069723348c8917deaa2", + "sha256:812ff431f10576a74c89be4e85b8e075a705be39bc40c4b4278b5b13e2afa9af" ], "index": "pypi", - "version": "==0.2.3" + "version": "==0.3.0" }, "flake8-tidy-imports": { "hashes": [ @@ -730,10 +745,10 @@ }, "idna": { "hashes": [ - "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", - "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" + "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb", + "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa" ], - "version": "==2.8" + "version": "==2.9" }, "importlib-metadata": { "hashes": [ @@ -818,11 +833,11 @@ }, "requests": { "hashes": [ - "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", - "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" + "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee", + "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6" ], "index": "pypi", - "version": "==2.22.0" + "version": "==2.23.0" }, "safety": { "hashes": [ @@ -898,17 +913,17 @@ }, "virtualenv": { "hashes": [ - "sha256:0d62c70883c0342d59c11d0ddac0d954d0431321a41ab20851facf2b222598f3", - "sha256:55059a7a676e4e19498f1aad09b8313a38fcc0cdbe4fdddc0e9b06946d21b4bb" + "sha256:08f3623597ce73b85d6854fb26608a6f39ee9d055c81178dc6583803797f8994", + "sha256:de2cbdd5926c48d7b84e0300dea9e8f276f61d186e8e49223d71d91250fbaebd" ], - "version": "==16.7.9" + "version": "==20.0.4" }, "zipp": { "hashes": [ - "sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28", - "sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98" + "sha256:12248a63bbdf7548f89cb4c7cda4681e537031eda29c02ea29674bc6854460c2", + "sha256:7c0f8e91abc0dc07a5068f315c52cb30c66bfbc581e5b50704c8a2f6ebae794a" ], - "version": "==2.1.0" + "version": "==3.0.0" } } } diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0400ac4d2..874364a6f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs: - script: python -m flake8 displayName: 'Run linter' - - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner + - script: BOT_API_KEY=foo BOT_SENTRY_DSN=blah BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner displayName: Run tests - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/__init__.py b/bot/__init__.py index 789ace5c0..f7a410706 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -4,11 +4,8 @@ import sys from logging import Logger, StreamHandler, handlers from pathlib import Path -from logmatic import JsonFormatter - - -logging.TRACE = 5 -logging.addLevelName(logging.TRACE, "TRACE") +TRACE_LEVEL = logging.TRACE = 5 +logging.addLevelName(TRACE_LEVEL, "TRACE") def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: @@ -20,75 +17,29 @@ def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) """ - if self.isEnabledFor(logging.TRACE): - self._log(logging.TRACE, msg, args, **kwargs) + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, msg, args, **kwargs) Logger.trace = monkeypatch_trace -# Set up logging -logging_handlers = [] - -# We can't import this yet, so we have to define it ourselves -DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False - -LOG_DIR = Path("logs") -LOG_DIR.mkdir(exist_ok=True) - -if DEBUG_MODE: - logging_handlers.append(StreamHandler(stream=sys.stdout)) - - json_handler = logging.FileHandler(filename=Path(LOG_DIR, "log.json"), mode="w") - json_handler.formatter = JsonFormatter() - logging_handlers.append(json_handler) -else: - - logfile = Path(LOG_DIR, "bot.log") - megabyte = 1048576 - - filehandler = handlers.RotatingFileHandler(logfile, maxBytes=(megabyte*5), backupCount=7) - logging_handlers.append(filehandler) - - json_handler = logging.StreamHandler(stream=sys.stdout) - json_handler.formatter = JsonFormatter() - logging_handlers.append(json_handler) - - -logging.basicConfig( - format="%(asctime)s Bot: | %(name)33s | %(levelname)8s | %(message)s", - datefmt="%b %d %H:%M:%S", - level=logging.TRACE if DEBUG_MODE else logging.INFO, - handlers=logging_handlers -) - -log = logging.getLogger(__name__) - - -for key, value in logging.Logger.manager.loggerDict.items(): - # Force all existing loggers to the correct level and handlers - # This happens long before we instantiate our loggers, so - # those should still have the expected level - - if key == "bot": - continue - - if not isinstance(value, logging.Logger): - # There might be some logging.PlaceHolder objects in there - continue +DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") - if DEBUG_MODE: - value.setLevel(logging.DEBUG) - else: - value.setLevel(logging.INFO) +log_format = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s") - for handler in value.handlers.copy(): - value.removeHandler(handler) +stream_handler = StreamHandler(stream=sys.stdout) +stream_handler.setFormatter(log_format) - for handler in logging_handlers: - value.addHandler(handler) +log_file = Path("logs", "bot.log") +log_file.parent.mkdir(exist_ok=True) +file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7) +file_handler.setFormatter(log_format) +root_log = logging.getLogger() +root_log.setLevel(TRACE_LEVEL if DEBUG_MODE else logging.INFO) +root_log.addHandler(stream_handler) +root_log.addHandler(file_handler) -# Silence irrelevant loggers -logging.getLogger("aio_pika").setLevel(logging.ERROR) -logging.getLogger("discord").setLevel(logging.ERROR) -logging.getLogger("websockets").setLevel(logging.ERROR) +logging.getLogger("discord").setLevel(logging.WARNING) +logging.getLogger("websockets").setLevel(logging.WARNING) +logging.getLogger(__name__) diff --git a/bot/__main__.py b/bot/__main__.py index 84bc7094b..3df477a6d 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,10 +1,23 @@ +import logging + import discord +import sentry_sdk from discord.ext.commands import when_mentioned_or +from sentry_sdk.integrations.logging import LoggingIntegration from bot import patches from bot.bot import Bot -from bot.constants import Bot as BotConfig, DEBUG_MODE +from bot.constants import Bot as BotConfig + +sentry_logging = LoggingIntegration( + level=logging.DEBUG, + event_level=logging.WARNING +) +sentry_sdk.init( + dsn=BotConfig.sentry_dsn, + integrations=[sentry_logging] +) bot = Bot( command_prefix=when_mentioned_or(BotConfig.prefix), @@ -18,6 +31,7 @@ bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") bot.load_extension("bot.cogs.logging") bot.load_extension("bot.cogs.security") +bot.load_extension("bot.cogs.config_verifier") # Commands, etc bot.load_extension("bot.cogs.antimalware") @@ -27,10 +41,8 @@ bot.load_extension("bot.cogs.clean") bot.load_extension("bot.cogs.extensions") bot.load_extension("bot.cogs.help") -# Only load this in production -if not DEBUG_MODE: - bot.load_extension("bot.cogs.doc") - bot.load_extension("bot.cogs.verification") +bot.load_extension("bot.cogs.doc") +bot.load_extension("bot.cogs.verification") # Feature cogs bot.load_extension("bot.cogs.alias") diff --git a/bot/api.py b/bot/api.py index 56db99828..e59916114 100644 --- a/bot/api.py +++ b/bot/api.py @@ -32,6 +32,11 @@ class ResponseCodeError(ValueError): class APIClient: """Django Site API wrapper.""" + # These are class attributes so they can be seen when being mocked for tests. + # See commit 22a55534ef13990815a6f69d361e2a12693075d5 for details. + session: Optional[aiohttp.ClientSession] = None + loop: asyncio.AbstractEventLoop = None + def __init__(self, loop: asyncio.AbstractEventLoop, **kwargs): auth_headers = { 'Authorization': f"Token {Keys.site_api}" @@ -42,7 +47,7 @@ class APIClient: else: kwargs['headers'] = auth_headers - self.session: Optional[aiohttp.ClientSession] = None + self.session = None self.loop = loop self._ready = asyncio.Event(loop=loop) @@ -85,43 +90,35 @@ class APIClient: response_text = await response.text() raise ResponseCodeError(response=response, response_text=response_text) - async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: - """Site API GET.""" + async def request(self, method: str, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Send an HTTP request to the site API and return the JSON response.""" await self._ready.wait() - async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: + async with self.session.request(method.upper(), self._url_for(endpoint), **kwargs) as resp: await self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() - async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: - """Site API PATCH.""" - await self._ready.wait() + async def get(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Site API GET.""" + return await self.request("GET", endpoint, raise_for_status=raise_for_status, **kwargs) - async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() + async def patch(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: + """Site API PATCH.""" + return await self.request("PATCH", endpoint, raise_for_status=raise_for_status, **kwargs) - async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: + async def post(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: """Site API POST.""" - await self._ready.wait() + return await self.request("POST", endpoint, raise_for_status=raise_for_status, **kwargs) - async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> dict: + async def put(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> dict: """Site API PUT.""" - await self._ready.wait() + return await self.request("PUT", endpoint, raise_for_status=raise_for_status, **kwargs) - async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: - await self.maybe_raise_for_status(resp, raise_for_status) - return await resp.json() - - async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs) -> Optional[dict]: + async def delete(self, endpoint: str, *, raise_for_status: bool = True, **kwargs) -> Optional[dict]: """Site API DELETE.""" await self._ready.wait() - async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: + async with self.session.delete(self._url_for(endpoint), **kwargs) as resp: if resp.status == 204: return None @@ -141,77 +138,3 @@ def loop_is_running() -> bool: except RuntimeError: return False return True - - -class APILoggingHandler(logging.StreamHandler): - """Site API logging handler.""" - - def __init__(self, client: APIClient): - logging.StreamHandler.__init__(self) - self.client = client - - # internal batch of shipoff tasks that must not be scheduled - # on the event loop yet - scheduled when the event loop is ready. - self.queue = [] - - async def ship_off(self, payload: dict) -> None: - """Ship log payload to the logging API.""" - try: - await self.client.post('logs', json=payload) - except ResponseCodeError as err: - log.warning( - "Cannot send logging record to the site, got code %d.", - err.response.status, - extra={'via_handler': True} - ) - except Exception as err: - log.warning( - "Cannot send logging record to the site: %r", - err, - extra={'via_handler': True} - ) - - def emit(self, record: logging.LogRecord) -> None: - """ - Determine if a log record should be shipped to the logging API. - - If the asyncio event loop is not yet running, log records will instead be put in a queue - which will be consumed once the event loop is running. - - The following two conditions are set: - 1. Do not log anything below DEBUG (only applies to the monkeypatched `TRACE` level) - 2. Ignore log records originating from this logging handler itself to prevent infinite recursion - """ - if ( - record.levelno >= logging.DEBUG - and not record.__dict__.get('via_handler') - ): - payload = { - 'application': 'bot', - 'logger_name': record.name, - 'level': record.levelname.lower(), - 'module': record.module, - 'line': record.lineno, - 'message': self.format(record) - } - - task = self.ship_off(payload) - if not loop_is_running(): - self.queue.append(task) - else: - asyncio.create_task(task) - self.schedule_queued_tasks() - - def schedule_queued_tasks(self) -> None: - """Consume the queue and schedule the logging of each queued record.""" - for task in self.queue: - asyncio.create_task(task) - - if self.queue: - log.debug( - "Scheduled %d pending logging tasks.", - len(self.queue), - extra={'via_handler': True} - ) - - self.queue.clear() diff --git a/bot/bot.py b/bot/bot.py index 8f808272f..19b9035c4 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,11 +1,14 @@ +import asyncio import logging import socket from typing import Optional import aiohttp +import discord from discord.ext import commands from bot import api +from bot import constants log = logging.getLogger('bot') @@ -17,17 +20,17 @@ class Bot(commands.Bot): # Use asyncio for DNS resolution instead of threads so threads aren't spammed. # Use AF_INET as its socket family to prevent HTTPS related problems both locally # and in production. - self.connector = aiohttp.TCPConnector( + self._connector = aiohttp.TCPConnector( resolver=aiohttp.AsyncResolver(), family=socket.AF_INET, ) - super().__init__(*args, connector=self.connector, **kwargs) + super().__init__(*args, connector=self._connector, **kwargs) - self.http_session: Optional[aiohttp.ClientSession] = None - self.api_client = api.APIClient(loop=self.loop, connector=self.connector) + self._guild_available = asyncio.Event() - log.addHandler(api.APILoggingHandler(self.api_client)) + self.http_session: Optional[aiohttp.ClientSession] = None + self.api_client = api.APIClient(loop=self.loop, connector=self._connector) def add_cog(self, cog: commands.Cog) -> None: """Adds a "cog" to the bot and logs the operation.""" @@ -48,6 +51,47 @@ class Bot(commands.Bot): async def start(self, *args, **kwargs) -> None: """Open an aiohttp session before logging in and connecting to Discord.""" - self.http_session = aiohttp.ClientSession(connector=self.connector) + self.http_session = aiohttp.ClientSession(connector=self._connector) await super().start(*args, **kwargs) + + async def on_guild_available(self, guild: discord.Guild) -> None: + """ + Set the internal guild available event when constants.Guild.id becomes available. + + If the cache appears to still be empty (no members, no channels, or no roles), the event + will not be set. + """ + if guild.id != constants.Guild.id: + return + + if not guild.roles or not guild.members or not guild.channels: + msg = "Guild available event was dispatched but the cache appears to still be empty!" + log.warning(msg) + + try: + webhook = await self.fetch_webhook(constants.Webhooks.dev_log) + except discord.HTTPException as e: + log.error(f"Failed to fetch webhook to send empty cache warning: status {e.status}") + else: + await webhook.send(f"<@&{constants.Roles.admin}> {msg}") + + return + + self._guild_available.set() + + async def on_guild_unavailable(self, guild: discord.Guild) -> None: + """Clear the internal guild available event when constants.Guild.id becomes unavailable.""" + if guild.id != constants.Guild.id: + return + + self._guild_available.clear() + + async def wait_until_guild_available(self) -> None: + """ + Wait until the constants.Guild.id guild is available (and the cache is ready). + + The on_ready event is inadequate because it only waits 2 seconds for a GUILD_CREATE + gateway event before giving up and thus not populating the cache for unavailable guilds. + """ + await self._guild_available.wait() diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 28e3e5d96..9e9e81364 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -4,7 +4,7 @@ from discord import Embed, Message, NotFound from discord.ext.commands import Cog from bot.bot import Bot -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, URLs +from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES, URLs log = logging.getLogger(__name__) @@ -18,7 +18,13 @@ class AntiMalware(Cog): @Cog.listener() async def on_message(self, message: Message) -> None: """Identify messages with prohibited attachments.""" - if not message.attachments: + # Return when message don't have attachment and don't moderate DMs + if not message.attachments or not message.guild: + return + + # Check if user is staff, if is, return + # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance + if hasattr(message.author, "roles") and any(role.id in STAFF_ROLES for role in message.author.roles): return embed = Embed() diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index f67ef6f05..baa6b9459 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -123,7 +123,7 @@ class AntiSpam(Cog): async def alert_on_validation_error(self) -> None: """Unloads the cog and alerts admins if configuration validation failed.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if self.validation_errors: body = "**The following errors were encountered:**\n" body += "\n".join(f"- {error}" for error in self.validation_errors.values()) diff --git a/bot/cogs/bot.py b/bot/cogs/bot.py index 73b1e8f41..f17135877 100644 --- a/bot/cogs/bot.py +++ b/bot/cogs/bot.py @@ -34,13 +34,12 @@ class BotCog(Cog, name="Bot"): Channels.help_5: 0, Channels.help_6: 0, Channels.help_7: 0, - Channels.python: 0, + Channels.python_discussion: 0, } # These channels will also work, but will not be subject to cooldown self.channel_whitelist = ( - Channels.bot, - Channels.devtest, + Channels.bot_commands, ) # Stores improperly formatted Python codeblock message ids and the corresponding bot message diff --git a/bot/cogs/clean.py b/bot/cogs/clean.py index 2104efe57..5cdf0b048 100644 --- a/bot/cogs/clean.py +++ b/bot/cogs/clean.py @@ -173,7 +173,7 @@ class Clean(Cog): colour=Colour(Colours.soft_red), title="Bulk message delete", text=message, - channel_id=Channels.modlog, + channel_id=Channels.mod_log, ) @group(invoke_without_command=True, name="clean", aliases=["purge"]) diff --git a/bot/cogs/config_verifier.py b/bot/cogs/config_verifier.py new file mode 100644 index 000000000..d72c6c22e --- /dev/null +++ b/bot/cogs/config_verifier.py @@ -0,0 +1,40 @@ +import logging + +from discord.ext.commands import Cog + +from bot import constants +from bot.bot import Bot + + +log = logging.getLogger(__name__) + + +class ConfigVerifier(Cog): + """Verify config on startup.""" + + def __init__(self, bot: Bot): + self.bot = bot + self.channel_verify_task = self.bot.loop.create_task(self.verify_channels()) + + async def verify_channels(self) -> None: + """ + Verify channels. + + If any channels in config aren't present in server, log them in a warning. + """ + await self.bot.wait_until_guild_available() + server = self.bot.get_guild(constants.Guild.id) + + server_channel_ids = {channel.id for channel in server.channels} + invalid_channels = [ + channel_name for channel_name, channel_id in constants.Channels + if channel_id not in server_channel_ids + ] + + if invalid_channels: + log.warning(f"Configured channels do not exist in server: {', '.join(invalid_channels)}.") + + +def setup(bot: Bot) -> None: + """Load the ConfigVerifier cog.""" + bot.add_cog(ConfigVerifier(bot)) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index 3e7350fcc..cc0f79fe8 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -59,7 +59,7 @@ class Defcon(Cog): async def sync_settings(self) -> None: """On cog load, try to synchronize DEFCON settings to the API.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() self.channel = await self.bot.fetch_channel(Channels.defcon) try: @@ -68,20 +68,20 @@ class Defcon(Cog): except Exception: # Yikes! log.exception("Unable to get DEFCON settings!") - await self.bot.get_channel(Channels.devlog).send( - f"<@&{Roles.admin}> **WARNING**: Unable to get DEFCON settings!" + await self.bot.get_channel(Channels.dev_log).send( + f"<@&{Roles.admins}> **WARNING**: Unable to get DEFCON settings!" ) else: if data["enabled"]: self.enabled = True self.days = timedelta(days=data["days"]) - log.warning(f"DEFCON enabled: {self.days.days} days") + log.info(f"DEFCON enabled: {self.days.days} days") else: self.enabled = False self.days = timedelta(days=0) - log.warning(f"DEFCON disabled") + log.info(f"DEFCON disabled") await self.update_channel_topic() @@ -118,7 +118,7 @@ class Defcon(Cog): ) @group(name='defcon', aliases=('dc',), invoke_without_command=True) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def defcon_group(self, ctx: Context) -> None: """Check the DEFCON status or run a subcommand.""" await ctx.invoke(self.bot.get_command("help"), "defcon") @@ -146,7 +146,7 @@ class Defcon(Cog): await self.send_defcon_log(action, ctx.author, error) @defcon_group.command(name='enable', aliases=('on', 'e')) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def enable_command(self, ctx: Context) -> None: """ Enable DEFCON mode. Useful in a pinch, but be sure you know what you're doing! @@ -159,7 +159,7 @@ class Defcon(Cog): await self.update_channel_topic() @defcon_group.command(name='disable', aliases=('off', 'd')) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def disable_command(self, ctx: Context) -> None: """Disable DEFCON mode. Useful in a pinch, but be sure you know what you're doing!""" self.enabled = False @@ -167,7 +167,7 @@ class Defcon(Cog): await self.update_channel_topic() @defcon_group.command(name='status', aliases=('s',)) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def status_command(self, ctx: Context) -> None: """Check the current status of DEFCON mode.""" embed = Embed( @@ -179,7 +179,7 @@ class Defcon(Cog): await ctx.send(embed=embed) @defcon_group.command(name='days') - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def days_command(self, ctx: Context, days: int) -> None: """Set how old an account must be to join the server, in days, with DEFCON mode enabled.""" self.days = timedelta(days=days) diff --git a/bot/cogs/doc.py b/bot/cogs/doc.py index 6e7c00b6a..204cffb37 100644 --- a/bot/cogs/doc.py +++ b/bot/cogs/doc.py @@ -157,7 +157,7 @@ class Doc(commands.Cog): async def init_refresh_inventory(self) -> None: """Refresh documentation inventory on cog initialization.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() await self.refresh_inventory() async def update_single( diff --git a/bot/cogs/duck_pond.py b/bot/cogs/duck_pond.py index 345d2856c..1f84a0609 100644 --- a/bot/cogs/duck_pond.py +++ b/bot/cogs/duck_pond.py @@ -22,7 +22,7 @@ class DuckPond(Cog): async def fetch_webhook(self) -> None: """Fetches the webhook object, so we can post to it.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() try: self.webhook = await self.bot.fetch_webhook(self.webhook_id) diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 52893b2ee..864450139 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -1,20 +1,9 @@ import contextlib import logging +import typing as t -from discord.ext.commands import ( - BadArgument, - BotMissingPermissions, - CheckFailure, - CommandError, - CommandInvokeError, - CommandNotFound, - CommandOnCooldown, - DisabledCommand, - MissingPermissions, - NoPrivateMessage, - UserInputError, -) -from discord.ext.commands import Cog, Context +from discord.ext.commands import Cog, Command, Context, errors +from sentry_sdk import push_scope from bot.api import ResponseCodeError from bot.bot import Bot @@ -31,126 +20,201 @@ class ErrorHandler(Cog): self.bot = bot @Cog.listener() - async def on_command_error(self, ctx: Context, e: CommandError) -> None: + async def on_command_error(self, ctx: Context, e: errors.CommandError) -> None: """ Provide generic command error handling. - Error handling is deferred to any local error handler, if present. - - Error handling emits a single error response, prioritized as follows: - 1. If the name fails to match a command but matches a tag, the tag is invoked - 2. Send a BadArgument error message to the invoking context & invoke the command's help - 3. Send a UserInputError error message to the invoking context & invoke the command's help - 4. Send a NoPrivateMessage error message to the invoking context - 5. Send a BotMissingPermissions error message to the invoking context - 6. Log a MissingPermissions error, no message is sent - 7. Send a InChannelCheckFailure error message to the invoking context - 8. Log CheckFailure, CommandOnCooldown, and DisabledCommand errors, no message is sent - 9. For CommandInvokeErrors, response is based on the type of error: - * 404: Error message is sent to the invoking context - * 400: Log the resopnse JSON, no message is sent - * 500 <= status <= 600: Error message is sent to the invoking context - 10. Otherwise, handling is deferred to `handle_unexpected_error` + Error handling is deferred to any local error handler, if present. This is done by + checking for the presence of a `handled` attribute on the error. + + Error handling emits a single error message in the invoking context `ctx` and a log message, + prioritised as follows: + + 1. If the name fails to match a command but matches a tag, the tag is invoked + * If CommandNotFound is raised when invoking the tag (determined by the presence of the + `invoked_from_error_handler` attribute), this error is treated as being unexpected + and therefore sends an error message + * Commands in the verification channel are ignored + 2. UserInputError: see `handle_user_input_error` + 3. CheckFailure: see `handle_check_failure` + 4. CommandOnCooldown: send an error message in the invoking context + 5. ResponseCodeError: see `handle_api_error` + 6. Otherwise, if not a DisabledCommand, handling is deferred to `handle_unexpected_error` """ command = ctx.command - parent = None + if hasattr(e, "handled"): + log.trace(f"Command {command} had its error already handled locally; ignoring.") + return + + # Try to look for a tag with the command's name if the command isn't found. + if isinstance(e, errors.CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + if ctx.channel.id != Channels.verification: + await self.try_get_tag(ctx) + return # Exit early to avoid logging. + elif isinstance(e, errors.UserInputError): + await self.handle_user_input_error(ctx, e) + elif isinstance(e, errors.CheckFailure): + await self.handle_check_failure(ctx, e) + elif isinstance(e, errors.CommandOnCooldown): + await ctx.send(e) + elif isinstance(e, errors.CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + await self.handle_api_error(ctx, e.original) + else: + await self.handle_unexpected_error(ctx, e.original) + return # Exit early to avoid logging. + elif not isinstance(e, errors.DisabledCommand): + # ConversionError, MaxConcurrencyReached, ExtensionError + await self.handle_unexpected_error(ctx, e) + return # Exit early to avoid logging. + + log.debug( + f"Command {command} invoked by {ctx.message.author} with error " + f"{e.__class__.__name__}: {e}" + ) + + async def get_help_command(self, command: t.Optional[Command]) -> t.Tuple: + """Return the help command invocation args to display help for `command`.""" + parent = None if command is not None: parent = command.parent # Retrieve the help command for the invoked command. if parent and command: - help_command = (self.bot.get_command("help"), parent.name, command.name) + return self.bot.get_command("help"), parent.name, command.name elif command: - help_command = (self.bot.get_command("help"), command.name) + return self.bot.get_command("help"), command.name else: - help_command = (self.bot.get_command("help"),) + return self.bot.get_command("help") - if hasattr(e, "handled"): - log.trace(f"Command {command} had its error already handled locally; ignoring.") + async def try_get_tag(self, ctx: Context) -> None: + """ + Attempt to display a tag by interpreting the command name as a tag name. + + The invocation of tags get respects its checks. Any CommandErrors raised will be handled + by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to + the context to prevent infinite recursion in the case of a CommandNotFound exception. + """ + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + try: + if not await tags_get_command.can_run(ctx): + log.debug(log_msg) + return + except errors.CommandError as tag_error: + log.debug(log_msg) + await self.on_command_error(ctx, tag_error) return - # Try to look for a tag with the command's name if the command isn't found. - if isinstance(e, CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): - if not ctx.channel.id == Channels.verification: - tags_get_command = self.bot.get_command("tags get") - ctx.invoked_from_error_handler = True - - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." - try: - if not await tags_get_command.can_run(ctx): - log.debug(log_msg) - return - except CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - - # Return to not raise the exception - with contextlib.suppress(ResponseCodeError): - await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) - return - elif isinstance(e, BadArgument): + # Return to not raise the exception + with contextlib.suppress(ResponseCodeError): + await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) + return + + async def handle_user_input_error(self, ctx: Context, e: errors.UserInputError) -> None: + """ + Send an error message in `ctx` for UserInputError, sometimes invoking the help command too. + + * MissingRequiredArgument: send an error message with arg name and the help command + * TooManyArguments: send an error message and the help command + * BadArgument: send an error message and the help command + * BadUnionArgument: send an error message including the error produced by the last converter + * ArgumentParsingError: send an error message + * Other: send an error message and the help command + """ + # TODO: use ctx.send_help() once PR #519 is merged. + help_command = await self.get_help_command(ctx.command) + + if isinstance(e, errors.MissingRequiredArgument): + await ctx.send(f"Missing required argument `{e.param.name}`.") + await ctx.invoke(*help_command) + elif isinstance(e, errors.TooManyArguments): + await ctx.send(f"Too many arguments provided.") + await ctx.invoke(*help_command) + elif isinstance(e, errors.BadArgument): await ctx.send(f"Bad argument: {e}\n") await ctx.invoke(*help_command) - elif isinstance(e, UserInputError): + elif isinstance(e, errors.BadUnionArgument): + await ctx.send(f"Bad argument: {e}\n```{e.errors[-1]}```") + elif isinstance(e, errors.ArgumentParsingError): + await ctx.send(f"Argument parsing error: {e}") + else: await ctx.send("Something about your input seems off. Check the arguments:") await ctx.invoke(*help_command) - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - elif isinstance(e, NoPrivateMessage): - await ctx.send("Sorry, this command can't be used in a private message!") - elif isinstance(e, BotMissingPermissions): - await ctx.send(f"Sorry, it looks like I don't have the permissions I need to do that.") - log.warning( - f"The bot is missing permissions to execute command {command}: {e.missing_perms}" - ) - elif isinstance(e, MissingPermissions): - log.debug( - f"{ctx.message.author} is missing permissions to invoke command {command}: " - f"{e.missing_perms}" + + @staticmethod + async def handle_check_failure(ctx: Context, e: errors.CheckFailure) -> None: + """ + Send an error message in `ctx` for certain types of CheckFailure. + + The following types are handled: + + * BotMissingPermissions + * BotMissingRole + * BotMissingAnyRole + * NoPrivateMessage + * InChannelCheckFailure + """ + bot_missing_errors = ( + errors.BotMissingPermissions, + errors.BotMissingRole, + errors.BotMissingAnyRole + ) + + if isinstance(e, bot_missing_errors): + await ctx.send( + f"Sorry, it looks like I don't have the permissions or roles I need to do that." ) - elif isinstance(e, InChannelCheckFailure): + elif isinstance(e, (InChannelCheckFailure, errors.NoPrivateMessage)): await ctx.send(e) - elif isinstance(e, (CheckFailure, CommandOnCooldown, DisabledCommand)): - log.debug( - f"Command {command} invoked by {ctx.message.author} with error " - f"{e.__class__.__name__}: {e}" - ) - elif isinstance(e, CommandInvokeError): - if isinstance(e.original, ResponseCodeError): - status = e.original.response.status - - if status == 404: - await ctx.send("There does not seem to be anything matching your query.") - elif status == 400: - content = await e.original.response.json() - log.debug(f"API responded with 400 for command {command}: %r.", content) - await ctx.send("According to the API, your request is malformed.") - elif 500 <= status < 600: - await ctx.send("Sorry, there seems to be an internal issue with the API.") - log.warning(f"API responded with {status} for command {command}") - else: - await ctx.send(f"Got an unexpected status code from the API (`{status}`).") - log.warning(f"Unexpected API response for command {command}: {status}") - else: - await self.handle_unexpected_error(ctx, e.original) + + @staticmethod + async def handle_api_error(ctx: Context, e: ResponseCodeError) -> None: + """Send an error message in `ctx` for ResponseCodeError and log it.""" + if e.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + log.debug(f"API responded with 404 for command {ctx.command}") + elif e.status == 400: + content = await e.response.json() + log.debug(f"API responded with 400 for command {ctx.command}: %r.", content) + await ctx.send("According to the API, your request is malformed.") + elif 500 <= e.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + log.warning(f"API responded with {e.status} for command {ctx.command}") else: - await self.handle_unexpected_error(ctx, e) + await ctx.send(f"Got an unexpected status code from the API (`{e.status}`).") + log.warning(f"Unexpected API response for command {ctx.command}: {e.status}") @staticmethod - async def handle_unexpected_error(ctx: Context, e: CommandError) -> None: - """Generic handler for errors without an explicit handler.""" + async def handle_unexpected_error(ctx: Context, e: errors.CommandError) -> None: + """Send a generic error message in `ctx` and log the exception as an error with exc_info.""" await ctx.send( f"Sorry, an unexpected error occurred. Please let us know!\n\n" f"```{e.__class__.__name__}: {e}```" ) - log.error( - f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}" - ) - raise e + + with push_scope() as scope: + scope.user = { + "id": ctx.author.id, + "username": str(ctx.author) + } + + scope.set_tag("command", ctx.command.qualified_name) + scope.set_tag("message_id", ctx.message.id) + scope.set_tag("channel_id", ctx.channel.id) + + scope.set_extra("full_message", ctx.message.content) + + if ctx.guild is not None: + scope.set_extra( + "jump_to", + f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" + ) + + log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e) def setup(bot: Bot) -> None: diff --git a/bot/cogs/eval.py b/bot/cogs/eval.py index 9c729f28a..52136fc8d 100644 --- a/bot/cogs/eval.py +++ b/bot/cogs/eval.py @@ -174,14 +174,14 @@ async def func(): # (None,) -> Any await ctx.send(f"```py\n{out}```", embed=embed) @group(name='internal', aliases=('int',)) - @with_role(Roles.owner, Roles.admin) + @with_role(Roles.owners, Roles.admins) async def internal_group(self, ctx: Context) -> None: """Internal commands. Top secret!""" if not ctx.invoked_subcommand: await ctx.invoke(self.bot.get_command("help"), "internal") @internal_group.command(name='eval', aliases=('e',)) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def eval(self, ctx: Context, *, code: str) -> None: """Run eval in a REPL-like format.""" code = code.strip("`") diff --git a/bot/cogs/extensions.py b/bot/cogs/extensions.py index f16e79fb7..b312e1a1d 100644 --- a/bot/cogs/extensions.py +++ b/bot/cogs/extensions.py @@ -221,7 +221,7 @@ class Extensions(commands.Cog): # This cannot be static (must have a __func__ attribute). def cog_check(self, ctx: Context) -> bool: """Only allow moderators and core developers to invoke the commands in this cog.""" - return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developer) + return with_role_check(ctx, *MODERATION_ROLES, Roles.core_developers) # This cannot be static (must have a __func__ attribute). async def cog_command_error(self, ctx: Context, error: Exception) -> None: diff --git a/bot/cogs/free.py b/bot/cogs/free.py index 49cab6172..02c02d067 100644 --- a/bot/cogs/free.py +++ b/bot/cogs/free.py @@ -22,7 +22,7 @@ class Free(Cog): PYTHON_HELP_ID = Categories.python_help @command(name="free", aliases=('f',)) - @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES) + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) async def free(self, ctx: Context, user: Member = None, seek: int = 2) -> None: """ Lists free help channels by likeliness of availability. diff --git a/bot/cogs/help.py b/bot/cogs/help.py index fd5bbc3ca..744722220 100644 --- a/bot/cogs/help.py +++ b/bot/cogs/help.py @@ -507,7 +507,7 @@ class Help(DiscordCog): """Custom Embed Pagination Help feature.""" @commands.command('help') - @redirect_output(destination_channel=Channels.bot, bypass_roles=STAFF_ROLES) + @redirect_output(destination_channel=Channels.bot_commands, bypass_roles=STAFF_ROLES) async def new_help(self, ctx: Context, *commands) -> None: """Shows Command Help.""" try: diff --git a/bot/cogs/information.py b/bot/cogs/information.py index 125d7ce24..49beca15b 100644 --- a/bot/cogs/information.py +++ b/bot/cogs/information.py @@ -2,14 +2,12 @@ import colorsys import logging import pprint import textwrap -import typing -from collections import defaultdict -from typing import Any, Mapping, Optional - -import discord -from discord import CategoryChannel, Colour, Embed, Member, Role, TextChannel, VoiceChannel, utils -from discord.ext import commands -from discord.ext.commands import BucketType, Cog, Context, command, group +from collections import Counter, defaultdict +from string import Template +from typing import Any, Mapping, Optional, Union + +from discord import Colour, Embed, Member, Message, Role, Status, utils +from discord.ext.commands import BucketType, Cog, Context, Paginator, command, group from discord.utils import escape_markdown from bot import constants @@ -32,8 +30,7 @@ class Information(Cog): async def roles_info(self, ctx: Context) -> None: """Returns a list of all roles and their corresponding IDs.""" # Sort the roles alphabetically and remove the @everyone role - roles = sorted(ctx.guild.roles, key=lambda role: role.name) - roles = [role for role in roles if role.name != "@everyone"] + roles = sorted(ctx.guild.roles[1:], key=lambda role: role.name) # Build a string role_string = "" @@ -46,20 +43,20 @@ class Information(Cog): colour=Colour.blurple(), description=role_string ) - embed.set_footer(text=f"Total roles: {len(roles)}") await ctx.send(embed=embed) @with_role(*constants.MODERATION_ROLES) @command(name="role") - async def role_info(self, ctx: Context, *roles: typing.Union[Role, str]) -> None: + async def role_info(self, ctx: Context, *roles: Union[Role, str]) -> None: """ Return information on a role or list of roles. To specify multiple roles just add to the arguments, delimit roles with spaces in them using quotation marks. """ parsed_roles = [] + failed_roles = [] for role_name in roles: if isinstance(role_name, Role): @@ -70,29 +67,29 @@ class Information(Cog): role = utils.find(lambda r: r.name.lower() == role_name.lower(), ctx.guild.roles) if not role: - await ctx.send(f":x: Could not convert `{role_name}` to a role") + failed_roles.append(role_name) continue parsed_roles.append(role) + if failed_roles: + await ctx.send( + ":x: I could not convert the following role names to a role: \n- " + "\n- ".join(failed_roles) + ) + for role in parsed_roles: + h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) + embed = Embed( title=f"{role.name} info", colour=role.colour, ) - embed.add_field(name="ID", value=role.id, inline=True) - embed.add_field(name="Colour (RGB)", value=f"#{role.colour.value:0>6x}", inline=True) - - h, s, v = colorsys.rgb_to_hsv(*role.colour.to_rgb()) - embed.add_field(name="Colour (HSV)", value=f"{h:.2f} {s:.2f} {v}", inline=True) - embed.add_field(name="Member count", value=len(role.members), inline=True) - embed.add_field(name="Position", value=role.position) - embed.add_field(name="Permission code", value=role.permissions.value, inline=True) await ctx.send(embed=embed) @@ -104,40 +101,23 @@ class Information(Cog): features = ", ".join(ctx.guild.features) region = ctx.guild.region - # How many of each type of channel? roles = len(ctx.guild.roles) - channels = ctx.guild.channels - text_channels = 0 - category_channels = 0 - voice_channels = 0 - for channel in channels: - if type(channel) == TextChannel: - text_channels += 1 - elif type(channel) == CategoryChannel: - category_channels += 1 - elif type(channel) == VoiceChannel: - voice_channels += 1 - - # How many of each user status? member_count = ctx.guild.member_count - members = ctx.guild.members - online = 0 - dnd = 0 - idle = 0 - offline = 0 - for member in members: - if str(member.status) == "online": - online += 1 - elif str(member.status) == "offline": - offline += 1 - elif str(member.status) == "idle": - idle += 1 - elif str(member.status) == "dnd": - dnd += 1 - embed = Embed( - colour=Colour.blurple(), - description=textwrap.dedent(f""" + # How many of each type of channel? + channels = Counter(c.type for c in ctx.guild.channels) + channel_counts = "".join(sorted(f"{str(ch).title()} channels: {channels[ch]}\n" for ch in channels)).strip() + + # How many of each user status? + statuses = Counter(member.status for member in ctx.guild.members) + embed = Embed(colour=Colour.blurple()) + + # Because channel_counts lacks leading whitespace, it breaks the dedent if it's inserted directly by the + # f-string. While this is correctly formated by Discord, it makes unit testing difficult. To keep the formatting + # without joining a tuple of strings we can use a Template string to insert the already-formatted channel_counts + # after the dedent is made. + embed.description = Template( + textwrap.dedent(f""" **Server information** Created: {created} Voice region: {region} @@ -146,18 +126,15 @@ class Information(Cog): **Counts** Members: {member_count:,} Roles: {roles} - Text: {text_channels} - Voice: {voice_channels} - Channel categories: {category_channels} + $channel_counts **Members** - {constants.Emojis.status_online} {online} - {constants.Emojis.status_idle} {idle} - {constants.Emojis.status_dnd} {dnd} - {constants.Emojis.status_offline} {offline} + {constants.Emojis.status_online} {statuses[Status.online]:,} + {constants.Emojis.status_idle} {statuses[Status.idle]:,} + {constants.Emojis.status_dnd} {statuses[Status.dnd]:,} + {constants.Emojis.status_offline} {statuses[Status.offline]:,} """) - ) - + ).substitute({"channel_counts": channel_counts}) embed.set_thumbnail(url=ctx.guild.icon_url) await ctx.send(embed=embed) @@ -169,14 +146,14 @@ class Information(Cog): user = ctx.author # Do a role check if this is being executed on someone other than the caller - if user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): + elif user != ctx.author and not with_role_check(ctx, *constants.MODERATION_ROLES): await ctx.send("You may not use this command on users other than yourself.") return # Non-staff may only do this in #bot-commands if not with_role_check(ctx, *constants.STAFF_ROLES): - if not ctx.channel.id == constants.Channels.bot: - raise InChannelCheckFailure(constants.Channels.bot) + if not ctx.channel.id == constants.Channels.bot_commands: + raise InChannelCheckFailure(constants.Channels.bot_commands) embed = await self.create_user_embed(ctx, user) @@ -202,7 +179,7 @@ class Information(Cog): name = f"{user.nick} ({name})" joined = time_since(user.joined_at, precision="days") - roles = ", ".join(role.mention for role in user.roles if role.name != "@everyone") + roles = ", ".join(role.mention for role in user.roles[1:]) description = [ textwrap.dedent(f""" @@ -355,14 +332,14 @@ class Information(Cog): @cooldown_with_role_bypass(2, 60 * 3, BucketType.member, bypass_roles=constants.STAFF_ROLES) @group(invoke_without_command=True) - @in_channel(constants.Channels.bot, bypass_roles=constants.STAFF_ROLES) - async def raw(self, ctx: Context, *, message: discord.Message, json: bool = False) -> None: + @in_channel(constants.Channels.bot_commands, bypass_roles=constants.STAFF_ROLES) + async def raw(self, ctx: Context, *, message: Message, json: bool = False) -> None: """Shows information about the raw API response.""" # I *guess* it could be deleted right as the command is invoked but I felt like it wasn't worth handling # doing this extra request is also much easier than trying to convert everything back into a dictionary again raw_data = await ctx.bot.http.get_message(message.channel.id, message.id) - paginator = commands.Paginator() + paginator = Paginator() def add_content(title: str, content: str) -> None: paginator.add_line(f'== {title} ==\n') @@ -390,7 +367,7 @@ class Information(Cog): await ctx.send(page) @raw.command() - async def json(self, ctx: Context, message: discord.Message) -> None: + async def json(self, ctx: Context, message: Message) -> None: """Shows information about the raw API response in a copy-pasteable Python format.""" await ctx.invoke(self.raw, message=message, json=True) diff --git a/bot/cogs/jams.py b/bot/cogs/jams.py index 985f28ce5..1d062b0c2 100644 --- a/bot/cogs/jams.py +++ b/bot/cogs/jams.py @@ -18,7 +18,7 @@ class CodeJams(commands.Cog): self.bot = bot @commands.command() - @with_role(Roles.admin) + @with_role(Roles.admins) async def createteam(self, ctx: commands.Context, team_name: str, members: commands.Greedy[Member]) -> None: """ Create team channels (voice and text) in the Code Jams category, assign roles, and add overwrites for the team. @@ -95,10 +95,10 @@ class CodeJams(commands.Cog): ) # Assign team leader role - await members[0].add_roles(ctx.guild.get_role(Roles.team_leader)) + await members[0].add_roles(ctx.guild.get_role(Roles.team_leaders)) # Assign rest of roles - jammer_role = ctx.guild.get_role(Roles.jammer) + jammer_role = ctx.guild.get_role(Roles.jammers) for member in members: await member.add_roles(jammer_role) diff --git a/bot/cogs/logging.py b/bot/cogs/logging.py index d1b7dcab3..94fa2b139 100644 --- a/bot/cogs/logging.py +++ b/bot/cogs/logging.py @@ -20,7 +20,7 @@ class Logging(Cog): async def startup_greeting(self) -> None: """Announce our presence to the configured devlog channel.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() log.info("Bot connected!") embed = Embed(description="Connected!") @@ -34,7 +34,7 @@ class Logging(Cog): ) if not DEBUG_MODE: - await self.bot.get_channel(Channels.devlog).send(embed=embed) + await self.bot.get_channel(Channels.dev_log).send(embed=embed) def setup(bot: Bot) -> None: diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index f4e296df9..9ea17b2b3 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -313,6 +313,6 @@ class Infractions(InfractionScheduler, commands.Cog): async def cog_command_error(self, ctx: Context, error: Exception) -> None: """Send a notification to the invoking context on a Union failure.""" if isinstance(error, commands.BadUnionArgument): - if discord.User in error.converters: + if discord.User in error.converters or discord.Member in error.converters: await ctx.send(str(error.errors[0])) error.handled = True diff --git a/bot/cogs/moderation/management.py b/bot/cogs/moderation/management.py index f2964cd78..f74089056 100644 --- a/bot/cogs/moderation/management.py +++ b/bot/cogs/moderation/management.py @@ -129,7 +129,9 @@ class ModManagement(commands.Cog): # Re-schedule infraction if the expiration has been updated if 'expires_at' in request_data: - self.infractions_cog.cancel_task(new_infraction['id']) + # A scheduled task should only exist if the old infraction wasn't permanent + if old_infraction['expires_at']: + self.infractions_cog.cancel_task(new_infraction['id']) # If the infraction was not marked as permanent, schedule a new expiration task if request_data['expires_at']: diff --git a/bot/cogs/moderation/modlog.py b/bot/cogs/moderation/modlog.py index e8ae0dbe6..59ae6b587 100644 --- a/bot/cogs/moderation/modlog.py +++ b/bot/cogs/moderation/modlog.py @@ -87,7 +87,7 @@ class ModLog(Cog, name="ModLog"): title: t.Optional[str], text: str, thumbnail: t.Optional[t.Union[str, discord.Asset]] = None, - channel_id: int = Channels.modlog, + channel_id: int = Channels.mod_log, ping_everyone: bool = False, files: t.Optional[t.List[discord.File]] = None, content: t.Optional[str] = None, @@ -377,7 +377,7 @@ class ModLog(Cog, name="ModLog"): Icons.user_ban, Colours.soft_red, "User banned", f"{member} (`{member.id}`)", thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -399,7 +399,7 @@ class ModLog(Cog, name="ModLog"): Icons.sign_in, Colours.soft_green, "User joined", message, thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -416,7 +416,7 @@ class ModLog(Cog, name="ModLog"): Icons.sign_out, Colours.soft_red, "User left", f"{member} (`{member.id}`)", thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -433,7 +433,7 @@ class ModLog(Cog, name="ModLog"): Icons.user_unban, Colour.blurple(), "User unbanned", f"{member} (`{member.id}`)", thumbnail=member.avatar_url_as(static_format="png"), - channel_id=Channels.modlog + channel_id=Channels.mod_log ) @Cog.listener() @@ -529,7 +529,7 @@ class ModLog(Cog, name="ModLog"): Icons.user_update, Colour.blurple(), "Member updated", message, thumbnail=after.avatar_url_as(static_format="png"), - channel_id=Channels.userlog + channel_id=Channels.user_log ) @Cog.listener() @@ -538,7 +538,7 @@ class ModLog(Cog, name="ModLog"): channel = message.channel author = message.author - if message.guild.id != GuildConstant.id or channel.id in GuildConstant.ignored: + if message.guild.id != GuildConstant.id or channel.id in GuildConstant.modlog_blacklist: return self._cached_deletes.append(message.id) @@ -591,7 +591,7 @@ class ModLog(Cog, name="ModLog"): @Cog.listener() async def on_raw_message_delete(self, event: discord.RawMessageDeleteEvent) -> None: """Log raw message delete event to message change log.""" - if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.ignored: + if event.guild_id != GuildConstant.id or event.channel_id in GuildConstant.modlog_blacklist: return await asyncio.sleep(1) # Wait here in case the normal event was fired @@ -635,7 +635,7 @@ class ModLog(Cog, name="ModLog"): if ( not msg_before.guild or msg_before.guild.id != GuildConstant.id - or msg_before.channel.id in GuildConstant.ignored + or msg_before.channel.id in GuildConstant.modlog_blacklist or msg_before.author.bot ): return @@ -717,7 +717,7 @@ class ModLog(Cog, name="ModLog"): if ( not message.guild or message.guild.id != GuildConstant.id - or message.channel.id in GuildConstant.ignored + or message.channel.id in GuildConstant.modlog_blacklist or message.author.bot ): return @@ -769,7 +769,7 @@ class ModLog(Cog, name="ModLog"): """Log member voice state changes to the voice log channel.""" if ( member.guild.id != GuildConstant.id - or (before.channel and before.channel.id in GuildConstant.ignored) + or (before.channel and before.channel.id in GuildConstant.modlog_blacklist) ): return diff --git a/bot/cogs/moderation/scheduler.py b/bot/cogs/moderation/scheduler.py index e14c302cb..45cf5ec8a 100644 --- a/bot/cogs/moderation/scheduler.py +++ b/bot/cogs/moderation/scheduler.py @@ -38,7 +38,7 @@ class InfractionScheduler(Scheduler): async def reschedule_infractions(self, supported_infractions: t.Container[str]) -> None: """Schedule expiration for previous infractions.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() log.trace(f"Rescheduling infractions for {self.__class__.__name__}.") @@ -307,18 +307,25 @@ class InfractionScheduler(Scheduler): Infractions of unsupported types will raise a ValueError. """ guild = self.bot.get_guild(constants.Guild.id) - mod_role = guild.get_role(constants.Roles.moderator) + mod_role = guild.get_role(constants.Roles.moderators) user_id = infraction["user"] + actor = infraction["actor"] type_ = infraction["type"] id_ = infraction["id"] + inserted_at = infraction["inserted_at"] + expiry = infraction["expires_at"] log.info(f"Marking infraction #{id_} as inactive (expired).") + expiry = dateutil.parser.isoparse(expiry).replace(tzinfo=None) if expiry else None + created = time.format_infraction_with_duration(inserted_at, expiry) + log_content = None log_text = { - "Member": str(user_id), - "Actor": str(self.bot.user), - "Reason": infraction["reason"] + "Member": f"<@{user_id}>", + "Actor": str(self.bot.get_user(actor) or actor), + "Reason": infraction["reason"], + "Created": created, } try: @@ -384,14 +391,19 @@ class InfractionScheduler(Scheduler): if send_log: log_title = f"expiration failed" if "Failure" in log_text else "expired" + user = self.bot.get_user(user_id) + avatar = user.avatar_url_as(static_format="png") if user else None + log.trace(f"Sending deactivation mod log for infraction #{id_}.") await self.mod_log.send_log_message( icon_url=utils.INFRACTION_ICONS[type_][1], colour=Colours.soft_green, title=f"Infraction {log_title}: {type_}", + thumbnail=avatar, text="\n".join(f"{k}: {v}" for k, v in log_text.items()), footer=f"ID: {id_}", content=log_content, + ) return log_text diff --git a/bot/cogs/moderation/superstarify.py b/bot/cogs/moderation/superstarify.py index 050c847ac..c41874a95 100644 --- a/bot/cogs/moderation/superstarify.py +++ b/bot/cogs/moderation/superstarify.py @@ -109,7 +109,8 @@ class Superstarify(InfractionScheduler, Cog): ctx: Context, member: Member, duration: Expiry, - reason: str = None + *, + reason: str = None, ) -> None: """ Temporarily force a random superstar name (like Taylor Swift) to be the user's nickname. diff --git a/bot/cogs/off_topic_names.py b/bot/cogs/off_topic_names.py index bf777ea5a..81511f99d 100644 --- a/bot/cogs/off_topic_names.py +++ b/bot/cogs/off_topic_names.py @@ -88,7 +88,7 @@ class OffTopicNames(Cog): async def init_offtopic_updater(self) -> None: """Start off-topic channel updating event loop if it hasn't already started.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if self.updater_task is None: coro = update_names(self.bot) self.updater_task = self.bot.loop.create_task(coro) diff --git a/bot/cogs/reddit.py b/bot/cogs/reddit.py index aa487f18e..5a7fa100f 100644 --- a/bot/cogs/reddit.py +++ b/bot/cogs/reddit.py @@ -43,12 +43,12 @@ class Reddit(Cog): def cog_unload(self) -> None: """Stop the loop task and revoke the access token when the cog is unloaded.""" self.auto_poster_loop.cancel() - if self.access_token.expires_at < datetime.utcnow(): - self.revoke_access_token() + if self.access_token and self.access_token.expires_at > datetime.utcnow(): + asyncio.create_task(self.revoke_access_token()) async def init_reddit_ready(self) -> None: """Sets the reddit webhook when the cog is loaded.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if not self.webhook: self.webhook = await self.bot.fetch_webhook(Webhooks.reddit) @@ -83,7 +83,7 @@ class Reddit(Cog): expires_at=datetime.utcnow() + timedelta(seconds=expiration) ) - log.debug(f"New token acquired; expires on {self.access_token.expires_at}") + log.debug(f"New token acquired; expires on UTC {self.access_token.expires_at}") return else: log.debug( @@ -208,7 +208,7 @@ class Reddit(Cog): await asyncio.sleep(seconds_until) - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() if not self.webhook: await self.bot.fetch_webhook(Webhooks.reddit) @@ -290,4 +290,7 @@ class Reddit(Cog): def setup(bot: Bot) -> None: """Load the Reddit cog.""" + if not RedditConfig.secret or not RedditConfig.client_id: + log.error("Credentials not provided, cog not loaded.") + return bot.add_cog(Reddit(bot)) diff --git a/bot/cogs/reminders.py b/bot/cogs/reminders.py index 45bf9a8f4..041791056 100644 --- a/bot/cogs/reminders.py +++ b/bot/cogs/reminders.py @@ -2,16 +2,17 @@ import asyncio import logging import random import textwrap +import typing as t from datetime import datetime, timedelta from operator import itemgetter -from typing import Optional +import discord +from dateutil.parser import isoparse from dateutil.relativedelta import relativedelta -from discord import Colour, Embed, Message from discord.ext.commands import Cog, Context, group from bot.bot import Bot -from bot.constants import Channels, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES +from bot.constants import Guild, Icons, NEGATIVE_REPLIES, POSITIVE_REPLIES, STAFF_ROLES from bot.converters import Duration from bot.pagination import LinePaginator from bot.utils.checks import without_role_check @@ -20,7 +21,7 @@ from bot.utils.time import humanize_delta, wait_until log = logging.getLogger(__name__) -WHITELISTED_CHANNELS = (Channels.bot,) +WHITELISTED_CHANNELS = Guild.reminder_whitelist MAXIMUM_REMINDERS = 5 @@ -35,7 +36,7 @@ class Reminders(Scheduler, Cog): async def reschedule_reminders(self) -> None: """Get all current reminders from the API and reschedule them.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() response = await self.bot.api_client.get( 'bot/reminders', params={'active': 'true'} @@ -45,29 +46,64 @@ class Reminders(Scheduler, Cog): loop = asyncio.get_event_loop() for reminder in response: - remind_at = datetime.fromisoformat(reminder['expiration'][:-1]) + is_valid, *_ = self.ensure_valid_reminder(reminder, cancel_task=False) + if not is_valid: + continue + + remind_at = isoparse(reminder['expiration']).replace(tzinfo=None) # If the reminder is already overdue ... if remind_at < now: late = relativedelta(now, remind_at) await self.send_reminder(reminder, late) - else: self.schedule_task(loop, reminder["id"], reminder) + def ensure_valid_reminder( + self, + reminder: dict, + cancel_task: bool = True + ) -> t.Tuple[bool, discord.User, discord.TextChannel]: + """Ensure reminder author and channel can be fetched otherwise delete the reminder.""" + user = self.bot.get_user(reminder['author']) + channel = self.bot.get_channel(reminder['channel_id']) + is_valid = True + if not user or not channel: + is_valid = False + log.info( + f"Reminder {reminder['id']} invalid: " + f"User {reminder['author']}={user}, Channel {reminder['channel_id']}={channel}." + ) + asyncio.create_task(self._delete_reminder(reminder['id'], cancel_task)) + + return is_valid, user, channel + @staticmethod - async def _send_confirmation(ctx: Context, on_success: str) -> None: + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: str, + delivery_dt: t.Optional[datetime], + ) -> None: """Send an embed confirming the reminder change was made successfully.""" - embed = Embed() - embed.colour = Colour.green() + embed = discord.Embed() + embed.colour = discord.Colour.green() embed.title = random.choice(POSITIVE_REPLIES) embed.description = on_success + + footer_str = f"ID: {reminder_id}" + if delivery_dt: + # Reminder deletion will have a `None` `delivery_dt` + footer_str = f"{footer_str}, Due: {delivery_dt.strftime('%Y-%m-%dT%H:%M:%S')}" + + embed.set_footer(text=footer_str) + await ctx.send(embed=embed) async def _scheduled_task(self, reminder: dict) -> None: """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" reminder_id = reminder["id"] - reminder_datetime = datetime.fromisoformat(reminder['expiration'][:-1]) + reminder_datetime = isoparse(reminder['expiration']).replace(tzinfo=None) # Send the reminder message once the desired duration has passed await wait_until(reminder_datetime) @@ -79,12 +115,13 @@ class Reminders(Scheduler, Cog): # Now we can begone with it from our schedule list. self.cancel_task(reminder_id) - async def _delete_reminder(self, reminder_id: str) -> None: + async def _delete_reminder(self, reminder_id: str, cancel_task: bool = True) -> None: """Delete a reminder from the database, given its ID, and cancel the running task.""" await self.bot.api_client.delete('bot/reminders/' + str(reminder_id)) - # Now we can remove it from the schedule list - self.cancel_task(reminder_id) + if cancel_task: + # Now we can remove it from the schedule list + self.cancel_task(reminder_id) async def _reschedule_reminder(self, reminder: dict) -> None: """Reschedule a reminder object.""" @@ -95,11 +132,12 @@ class Reminders(Scheduler, Cog): async def send_reminder(self, reminder: dict, late: relativedelta = None) -> None: """Send the reminder.""" - channel = self.bot.get_channel(reminder["channel_id"]) - user = self.bot.get_user(reminder["author"]) + is_valid, user, channel = self.ensure_valid_reminder(reminder) + if not is_valid: + return - embed = Embed() - embed.colour = Colour.blurple() + embed = discord.Embed() + embed.colour = discord.Colour.blurple() embed.set_author( icon_url=Icons.remind_blurple, name="It has arrived!" @@ -111,7 +149,7 @@ class Reminders(Scheduler, Cog): embed.description += f"\n[Jump back to when you created the reminder]({reminder['jump_url']})" if late: - embed.colour = Colour.red() + embed.colour = discord.Colour.red() embed.set_author( icon_url=Icons.remind_red, name=f"Sorry it arrived {humanize_delta(late, max_units=2)} late!" @@ -129,20 +167,20 @@ class Reminders(Scheduler, Cog): await ctx.invoke(self.new_reminder, expiration=expiration, content=content) @remind_group.command(name="new", aliases=("add", "create")) - async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> Optional[Message]: + async def new_reminder(self, ctx: Context, expiration: Duration, *, content: str) -> t.Optional[discord.Message]: """ Set yourself a simple reminder. Expiration is parsed per: http://strftime.org/ """ - embed = Embed() + embed = discord.Embed() # If the user is not staff, we need to verify whether or not to make a reminder at all. if without_role_check(ctx, *STAFF_ROLES): # If they don't have permission to set a reminder in this channel if ctx.channel.id not in WHITELISTED_CHANNELS: - embed.colour = Colour.red() + embed.colour = discord.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = "Sorry, you can't do that here!" @@ -159,7 +197,7 @@ class Reminders(Scheduler, Cog): # Let's limit this, so we don't get 10 000 # reminders from kip or something like that :P if len(active_reminders) > MAXIMUM_REMINDERS: - embed.colour = Colour.red() + embed.colour = discord.Colour.red() embed.title = random.choice(NEGATIVE_REPLIES) embed.description = "You have too many active reminders!" @@ -178,18 +216,21 @@ class Reminders(Scheduler, Cog): ) now = datetime.utcnow() - timedelta(seconds=1) + humanized_delta = humanize_delta(relativedelta(expiration, now)) # Confirm to the user that it worked. await self._send_confirmation( ctx, - on_success=f"Your reminder will arrive in {humanize_delta(relativedelta(expiration, now))}!" + on_success=f"Your reminder will arrive in {humanized_delta}!", + reminder_id=reminder["id"], + delivery_dt=expiration, ) loop = asyncio.get_event_loop() self.schedule_task(loop, reminder["id"], reminder) @remind_group.command(name="list") - async def list_reminders(self, ctx: Context) -> Optional[Message]: + async def list_reminders(self, ctx: Context) -> t.Optional[discord.Message]: """View a paginated embed of all reminders for your user.""" # Get all the user's reminders from the database. data = await self.bot.api_client.get( @@ -212,7 +253,7 @@ class Reminders(Scheduler, Cog): for content, remind_at, id_ in reminders: # Parse and humanize the time, make it pretty :D - remind_datetime = datetime.fromisoformat(remind_at[:-1]) + remind_datetime = isoparse(remind_at).replace(tzinfo=None) time = humanize_delta(relativedelta(remind_datetime, now)) text = textwrap.dedent(f""" @@ -222,8 +263,8 @@ class Reminders(Scheduler, Cog): lines.append(text) - embed = Embed() - embed.colour = Colour.blurple() + embed = discord.Embed() + embed.colour = discord.Colour.blurple() embed.title = f"Reminders for {ctx.author}" # Remind the user that they have no reminders :^) @@ -232,7 +273,7 @@ class Reminders(Scheduler, Cog): return await ctx.send(embed=embed) # Construct the embed and paginate it. - embed.colour = Colour.blurple() + embed.colour = discord.Colour.blurple() await LinePaginator.paginate( lines, @@ -261,7 +302,10 @@ class Reminders(Scheduler, Cog): # Send a confirmation message to the channel await self._send_confirmation( - ctx, on_success="That reminder has been edited successfully!" + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, ) await self._reschedule_reminder(reminder) @@ -275,18 +319,27 @@ class Reminders(Scheduler, Cog): json={'content': content} ) + # Parse the reminder expiration back into a datetime for the confirmation message + expiration = isoparse(reminder['expiration']).replace(tzinfo=None) + # Send a confirmation message to the channel await self._send_confirmation( - ctx, on_success="That reminder has been edited successfully!" + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=id_, + delivery_dt=expiration, ) await self._reschedule_reminder(reminder) - @remind_group.command("delete", aliases=("remove",)) + @remind_group.command("delete", aliases=("remove", "cancel")) async def delete_reminder(self, ctx: Context, id_: int) -> None: """Delete one of your active reminders.""" await self._delete_reminder(id_) await self._send_confirmation( - ctx, on_success="That reminder has been deleted successfully!" + ctx, + on_success="That reminder has been deleted successfully!", + reminder_id=id_, + delivery_dt=None, ) diff --git a/bot/cogs/snekbox.py b/bot/cogs/snekbox.py index d52027ac6..44764e7e9 100644 --- a/bot/cogs/snekbox.py +++ b/bot/cogs/snekbox.py @@ -38,7 +38,7 @@ RAW_CODE_REGEX = re.compile( ) MAX_PASTE_LEN = 1000 -EVAL_ROLES = (Roles.helpers, Roles.moderator, Roles.admin, Roles.owner, Roles.rockstars, Roles.partners) +EVAL_ROLES = (Roles.helpers, Roles.moderators, Roles.admins, Roles.owners, Roles.python_community, Roles.partners) SIGKILL = 9 @@ -245,7 +245,7 @@ class Snekbox(Cog): @command(name="eval", aliases=("e",)) @guild_only() - @in_channel(Channels.bot, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) + @in_channel(Channels.bot_commands, hidden_channels=(Channels.esoteric,), bypass_roles=EVAL_ROLES) async def eval_command(self, ctx: Context, *, code: str = None) -> None: """ Run Python code and get the results. diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index 4e6ed156b..5708be3f4 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -1,7 +1,7 @@ import logging -from typing import Callable, Dict, Iterable, Union +from typing import Any, Dict -from discord import Guild, Member, Role, User +from discord import Member, Role, User from discord.ext import commands from discord.ext.commands import Cog, Context @@ -16,45 +16,28 @@ log = logging.getLogger(__name__) class Sync(Cog): """Captures relevant events and sends them to the site.""" - # The server to synchronize events on. - # Note that setting this wrongly will result in things getting deleted - # that possibly shouldn't be. - SYNC_SERVER_ID = constants.Guild.id - - # An iterable of callables that are called when the bot is ready. - ON_READY_SYNCERS: Iterable[Callable[[Bot, Guild], None]] = ( - syncers.sync_roles, - syncers.sync_users - ) - def __init__(self, bot: Bot) -> None: self.bot = bot + self.role_syncer = syncers.RoleSyncer(self.bot) + self.user_syncer = syncers.UserSyncer(self.bot) self.bot.loop.create_task(self.sync_guild()) async def sync_guild(self) -> None: """Syncs the roles/users of the guild with the database.""" - await self.bot.wait_until_ready() - guild = self.bot.get_guild(self.SYNC_SERVER_ID) - if guild is not None: - for syncer in self.ON_READY_SYNCERS: - syncer_name = syncer.__name__[5:] # drop off `sync_` - log.info("Starting `%s` syncer.", syncer_name) - total_created, total_updated, total_deleted = await syncer(self.bot, guild) - if total_deleted is None: - log.info( - f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`." - ) - else: - log.info( - f"`{syncer_name}` syncer finished, created `{total_created}`, updated `{total_updated}`, " - f"deleted `{total_deleted}`." - ) - - async def patch_user(self, user_id: int, updated_information: Dict[str, Union[str, int]]) -> None: + await self.bot.wait_until_guild_available() + + guild = self.bot.get_guild(constants.Guild.id) + if guild is None: + return + + for syncer in (self.role_syncer, self.user_syncer): + await syncer.sync(guild) + + async def patch_user(self, user_id: int, updated_information: Dict[str, Any]) -> None: """Send a PATCH request to partially update a user in the database.""" try: - await self.bot.api_client.patch("bot/users/" + str(user_id), json=updated_information) + await self.bot.api_client.patch(f"bot/users/{user_id}", json=updated_information) except ResponseCodeError as e: if e.response.status != 404: raise @@ -82,12 +65,14 @@ class Sync(Cog): @Cog.listener() async def on_guild_role_update(self, before: Role, after: Role) -> None: """Syncs role with the database if any of the stored attributes were updated.""" - if ( - before.name != after.name - or before.colour != after.colour - or before.permissions != after.permissions - or before.position != after.position - ): + was_updated = ( + before.name != after.name + or before.colour != after.colour + or before.permissions != after.permissions + or before.position != after.position + ) + + if was_updated: await self.bot.api_client.put( f'bot/roles/{after.id}', json={ @@ -137,18 +122,8 @@ class Sync(Cog): @Cog.listener() async def on_member_remove(self, member: Member) -> None: - """Updates the user information when a member leaves the guild.""" - await self.bot.api_client.put( - f'bot/users/{member.id}', - json={ - 'avatar_hash': member.avatar, - 'discriminator': int(member.discriminator), - 'id': member.id, - 'in_guild': False, - 'name': member.name, - 'roles': sorted(role.id for role in member.roles) - } - ) + """Set the in_guild field to False when a member leaves the guild.""" + await self.patch_user(member.id, updated_information={"in_guild": False}) @Cog.listener() async def on_member_update(self, before: Member, after: Member) -> None: @@ -160,7 +135,8 @@ class Sync(Cog): @Cog.listener() async def on_user_update(self, before: User, after: User) -> None: """Update the user information in the database if a relevant change is detected.""" - if any(getattr(before, attr) != getattr(after, attr) for attr in ("name", "discriminator", "avatar")): + attrs = ("name", "discriminator", "avatar") + if any(getattr(before, attr) != getattr(after, attr) for attr in attrs): updated_information = { "name": after.name, "discriminator": int(after.discriminator), @@ -176,25 +152,11 @@ class Sync(Cog): @sync_group.command(name='roles') @commands.has_permissions(administrator=True) async def sync_roles_command(self, ctx: Context) -> None: - """Manually synchronize the guild's roles with the roles on the site.""" - initial_response = await ctx.send("📊 Synchronizing roles.") - total_created, total_updated, total_deleted = await syncers.sync_roles(self.bot, ctx.guild) - await initial_response.edit( - content=( - f"👌 Role synchronization complete, created **{total_created}** " - f", updated **{total_created}** roles, and deleted **{total_deleted}** roles." - ) - ) + """Manually synchronise the guild's roles with the roles on the site.""" + await self.role_syncer.sync(ctx.guild, ctx) @sync_group.command(name='users') @commands.has_permissions(administrator=True) async def sync_users_command(self, ctx: Context) -> None: - """Manually synchronize the guild's users with the users on the site.""" - initial_response = await ctx.send("📊 Synchronizing users.") - total_created, total_updated, total_deleted = await syncers.sync_users(self.bot, ctx.guild) - await initial_response.edit( - content=( - f"👌 User synchronization complete, created **{total_created}** " - f"and updated **{total_created}** users." - ) - ) + """Manually synchronise the guild's users with the users on the site.""" + await self.user_syncer.sync(ctx.guild, ctx) diff --git a/bot/cogs/sync/syncers.py b/bot/cogs/sync/syncers.py index 14cf51383..d6891168f 100644 --- a/bot/cogs/sync/syncers.py +++ b/bot/cogs/sync/syncers.py @@ -1,235 +1,342 @@ +import abc +import logging +import typing as t from collections import namedtuple -from typing import Dict, Set, Tuple +from functools import partial -from discord import Guild +from discord import Guild, HTTPException, Member, Message, Reaction, User +from discord.ext.commands import Context +from bot import constants +from bot.api import ResponseCodeError from bot.bot import Bot +log = logging.getLogger(__name__) + # These objects are declared as namedtuples because tuples are hashable, # something that we make use of when diffing site roles against guild roles. -Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) -User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild')) - - -def get_roles_for_sync( - guild_roles: Set[Role], api_roles: Set[Role] -) -> Tuple[Set[Role], Set[Role], Set[Role]]: - """ - Determine which roles should be created or updated on the site. - - Arguments: - guild_roles (Set[Role]): - Roles that were found on the guild at startup. - - api_roles (Set[Role]): - Roles that were retrieved from the API at startup. - - Returns: - Tuple[Set[Role], Set[Role]. Set[Role]]: - A tuple with three elements. The first element represents - roles to be created on the site, meaning that they were - present on the cached guild but not on the API. The second - element represents roles to be updated, meaning they were - present on both the cached guild and the API but non-ID - fields have changed inbetween. The third represents roles - to be deleted on the site, meaning the roles are present on - the API but not in the cached guild. - """ - guild_role_ids = {role.id for role in guild_roles} - api_role_ids = {role.id for role in api_roles} - new_role_ids = guild_role_ids - api_role_ids - deleted_role_ids = api_role_ids - guild_role_ids - - # New roles are those which are on the cached guild but not on the - # API guild, going by the role ID. We need to send them in for creation. - roles_to_create = {role for role in guild_roles if role.id in new_role_ids} - roles_to_update = guild_roles - api_roles - roles_to_create - roles_to_delete = {role for role in api_roles if role.id in deleted_role_ids} - return roles_to_create, roles_to_update, roles_to_delete - - -async def sync_roles(bot: Bot, guild: Guild) -> Tuple[int, int, int]: - """ - Synchronize roles found on the given `guild` with the ones on the API. - - Arguments: - bot (bot.bot.Bot): - The bot instance that we're running with. - - guild (discord.Guild): - The guild instance from the bot's cache - to synchronize roles with. - - Returns: - Tuple[int, int, int]: - A tuple with three integers representing how many roles were created - (element `0`) , how many roles were updated (element `1`), and how many - roles were deleted (element `2`) on the API. - """ - roles = await bot.api_client.get('bot/roles') - - # Pack API roles and guild roles into one common format, - # which is also hashable. We need hashability to be able - # to compare these easily later using sets. - api_roles = {Role(**role_dict) for role_dict in roles} - guild_roles = { - Role( - id=role.id, name=role.name, - colour=role.colour.value, permissions=role.permissions.value, - position=role.position, - ) - for role in guild.roles - } - roles_to_create, roles_to_update, roles_to_delete = get_roles_for_sync(guild_roles, api_roles) - - for role in roles_to_create: - await bot.api_client.post( - 'bot/roles', - json={ - 'id': role.id, - 'name': role.name, - 'colour': role.colour, - 'permissions': role.permissions, - 'position': role.position, - } - ) +_Role = namedtuple('Role', ('id', 'name', 'colour', 'permissions', 'position')) +_User = namedtuple('User', ('id', 'name', 'discriminator', 'avatar_hash', 'roles', 'in_guild')) +_Diff = namedtuple('Diff', ('created', 'updated', 'deleted')) - for role in roles_to_update: - await bot.api_client.put( - f'bot/roles/{role.id}', - json={ - 'id': role.id, - 'name': role.name, - 'colour': role.colour, - 'permissions': role.permissions, - 'position': role.position, - } - ) - for role in roles_to_delete: - await bot.api_client.delete(f'bot/roles/{role.id}') - - return len(roles_to_create), len(roles_to_update), len(roles_to_delete) - - -def get_users_for_sync( - guild_users: Dict[int, User], api_users: Dict[int, User] -) -> Tuple[Set[User], Set[User]]: - """ - Determine which users should be created or updated on the website. - - Arguments: - guild_users (Dict[int, User]): - A mapping of user IDs to user data, populated from the - guild cached on the running bot instance. - - api_users (Dict[int, User]): - A mapping of user IDs to user data, populated from the API's - current inventory of all users. - - Returns: - Tuple[Set[User], Set[User]]: - Two user sets as a tuple. The first element represents users - to be created on the website, these are users that are present - in the cached guild data but not in the API at all, going by - their ID. The second element represents users to update. It is - populated by users which are present on both the API and the - guild, but where the attribute of a user on the API is not - equal to the attribute of the user on the guild. - """ - users_to_create = set() - users_to_update = set() - - for api_user in api_users.values(): - guild_user = guild_users.get(api_user.id) - if guild_user is not None: - if api_user != guild_user: - users_to_update.add(guild_user) - - elif api_user.in_guild: - # The user is known on the API but not the guild, and the - # API currently specifies that the user is a member of the guild. - # This means that the user has left since the last sync. - # Update the `in_guild` attribute of the user on the site - # to signify that the user left. - new_api_user = api_user._replace(in_guild=False) - users_to_update.add(new_api_user) - - new_user_ids = set(guild_users.keys()) - set(api_users.keys()) - for user_id in new_user_ids: - # The user is known on the guild but not on the API. This means - # that the user has joined since the last sync. Create it. - new_user = guild_users[user_id] - users_to_create.add(new_user) - - return users_to_create, users_to_update - - -async def sync_users(bot: Bot, guild: Guild) -> Tuple[int, int, None]: - """ - Synchronize users found in the given `guild` with the ones in the API. - - Arguments: - bot (bot.bot.Bot): - The bot instance that we're running with. - - guild (discord.Guild): - The guild instance from the bot's cache - to synchronize roles with. - - Returns: - Tuple[int, int, None]: - A tuple with two integers, representing how many users were created - (element `0`) and how many users were updated (element `1`), and `None` - to indicate that a user sync never deletes entries from the API. - """ - current_users = await bot.api_client.get('bot/users') - - # Pack API users and guild users into one common format, - # which is also hashable. We need hashability to be able - # to compare these easily later using sets. - api_users = { - user_dict['id']: User( - roles=tuple(sorted(user_dict.pop('roles'))), - **user_dict - ) - for user_dict in current_users - } - guild_users = { - member.id: User( - id=member.id, name=member.name, - discriminator=int(member.discriminator), avatar_hash=member.avatar, - roles=tuple(sorted(role.id for role in member.roles)), in_guild=True - ) - for member in guild.members - } - - users_to_create, users_to_update = get_users_for_sync(guild_users, api_users) - - for user in users_to_create: - await bot.api_client.post( - 'bot/users', - json={ - 'avatar_hash': user.avatar_hash, - 'discriminator': user.discriminator, - 'id': user.id, - 'in_guild': user.in_guild, - 'name': user.name, - 'roles': list(user.roles) - } +class Syncer(abc.ABC): + """Base class for synchronising the database with objects in the Discord cache.""" + + _CORE_DEV_MENTION = f"<@&{constants.Roles.core_developers}> " + _REACTION_EMOJIS = (constants.Emojis.check_mark, constants.Emojis.cross_mark) + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the syncer; used in output messages and logging.""" + raise NotImplementedError # pragma: no cover + + async def _send_prompt(self, message: t.Optional[Message] = None) -> t.Optional[Message]: + """ + Send a prompt to confirm or abort a sync using reactions and return the sent message. + + If a message is given, it is edited to display the prompt and reactions. Otherwise, a new + message is sent to the dev-core channel and mentions the core developers role. If the + channel cannot be retrieved, return None. + """ + log.trace(f"Sending {self.name} sync confirmation prompt.") + + msg_content = ( + f'Possible cache issue while syncing {self.name}s. ' + f'More than {constants.Sync.max_diff} {self.name}s were changed. ' + f'React to confirm or abort the sync.' ) - for user in users_to_update: - await bot.api_client.put( - f'bot/users/{user.id}', - json={ - 'avatar_hash': user.avatar_hash, - 'discriminator': user.discriminator, - 'id': user.id, - 'in_guild': user.in_guild, - 'name': user.name, - 'roles': list(user.roles) - } + # Send to core developers if it's an automatic sync. + if not message: + log.trace("Message not provided for confirmation; creating a new one in dev-core.") + channel = self.bot.get_channel(constants.Channels.dev_core) + + if not channel: + log.debug("Failed to get the dev-core channel from cache; attempting to fetch it.") + try: + channel = await self.bot.fetch_channel(constants.Channels.dev_core) + except HTTPException: + log.exception( + f"Failed to fetch channel for sending sync confirmation prompt; " + f"aborting {self.name} sync." + ) + return None + + message = await channel.send(f"{self._CORE_DEV_MENTION}{msg_content}") + else: + await message.edit(content=msg_content) + + # Add the initial reactions. + log.trace(f"Adding reactions to {self.name} syncer confirmation prompt.") + for emoji in self._REACTION_EMOJIS: + await message.add_reaction(emoji) + + return message + + def _reaction_check( + self, + author: Member, + message: Message, + reaction: Reaction, + user: t.Union[Member, User] + ) -> bool: + """ + Return True if the `reaction` is a valid confirmation or abort reaction on `message`. + + If the `author` of the prompt is a bot, then a reaction by any core developer will be + considered valid. Otherwise, the author of the reaction (`user`) will have to be the + `author` of the prompt. + """ + # For automatic syncs, check for the core dev role instead of an exact author + has_role = any(constants.Roles.core_developers == role.id for role in user.roles) + return ( + reaction.message.id == message.id + and not user.bot + and (has_role if author.bot else user == author) + and str(reaction.emoji) in self._REACTION_EMOJIS ) - return len(users_to_create), len(users_to_update), None + async def _wait_for_confirmation(self, author: Member, message: Message) -> bool: + """ + Wait for a confirmation reaction by `author` on `message` and return True if confirmed. + + Uses the `_reaction_check` function to determine if a reaction is valid. + + If there is no reaction within `bot.constants.Sync.confirm_timeout` seconds, return False. + To acknowledge the reaction (or lack thereof), `message` will be edited. + """ + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + reaction = None + try: + log.trace(f"Waiting for a reaction to the {self.name} syncer confirmation prompt.") + reaction, _ = await self.bot.wait_for( + 'reaction_add', + check=partial(self._reaction_check, author, message), + timeout=constants.Sync.confirm_timeout + ) + except TimeoutError: + # reaction will remain none thus sync will be aborted in the finally block below. + log.debug(f"The {self.name} syncer confirmation prompt timed out.") + finally: + if str(reaction) == constants.Emojis.check_mark: + log.trace(f"The {self.name} syncer was confirmed.") + await message.edit(content=f':ok_hand: {mention}{self.name} sync will proceed.') + return True + else: + log.warning(f"The {self.name} syncer was aborted or timed out!") + await message.edit( + content=f':warning: {mention}{self.name} sync aborted or timed out!' + ) + return False + + @abc.abstractmethod + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference between the cache of `guild` and the database.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + async def _sync(self, diff: _Diff) -> None: + """Perform the API calls for synchronisation.""" + raise NotImplementedError # pragma: no cover + + async def _get_confirmation_result( + self, + diff_size: int, + author: Member, + message: t.Optional[Message] = None + ) -> t.Tuple[bool, t.Optional[Message]]: + """ + Prompt for confirmation and return a tuple of the result and the prompt message. + + `diff_size` is the size of the diff of the sync. If it is greater than + `bot.constants.Sync.max_diff`, the prompt will be sent. The `author` is the invoked of the + sync and the `message` is an extant message to edit to display the prompt. + + If confirmed or no confirmation was needed, the result is True. The returned message will + either be the given `message` or a new one which was created when sending the prompt. + """ + log.trace(f"Determining if confirmation prompt should be sent for {self.name} syncer.") + if diff_size > constants.Sync.max_diff: + message = await self._send_prompt(message) + if not message: + return False, None # Couldn't get channel. + + confirmed = await self._wait_for_confirmation(author, message) + if not confirmed: + return False, message # Sync aborted. + + return True, message + + async def sync(self, guild: Guild, ctx: t.Optional[Context] = None) -> None: + """ + Synchronise the database with the cache of `guild`. + + If the differences between the cache and the database are greater than + `bot.constants.Sync.max_diff`, then a confirmation prompt will be sent to the dev-core + channel. The confirmation can be optionally redirect to `ctx` instead. + """ + log.info(f"Starting {self.name} syncer.") + + message = None + author = self.bot.user + if ctx: + message = await ctx.send(f"📊 Synchronising {self.name}s.") + author = ctx.author + + diff = await self._get_diff(guild) + diff_dict = diff._asdict() # Ugly method for transforming the NamedTuple into a dict + totals = {k: len(v) for k, v in diff_dict.items() if v is not None} + diff_size = sum(totals.values()) + + confirmed, message = await self._get_confirmation_result(diff_size, author, message) + if not confirmed: + return + + # Preserve the core-dev role mention in the message edits so users aren't confused about + # where notifications came from. + mention = self._CORE_DEV_MENTION if author.bot else "" + + try: + await self._sync(diff) + except ResponseCodeError as e: + log.exception(f"{self.name} syncer failed!") + + # Don't show response text because it's probably some really long HTML. + results = f"status {e.status}\n```{e.response_json or 'See log output for details'}```" + content = f":x: {mention}Synchronisation of {self.name}s failed: {results}" + else: + results = ", ".join(f"{name} `{total}`" for name, total in totals.items()) + log.info(f"{self.name} syncer finished: {results}.") + content = f":ok_hand: {mention}Synchronisation of {self.name}s complete: {results}" + + if message: + await message.edit(content=content) + + +class RoleSyncer(Syncer): + """Synchronise the database with roles in the cache.""" + + name = "role" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of roles between the cache of `guild` and the database.""" + log.trace("Getting the diff for roles.") + roles = await self.bot.api_client.get('bot/roles') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_roles = {_Role(**role_dict) for role_dict in roles} + guild_roles = { + _Role( + id=role.id, + name=role.name, + colour=role.colour.value, + permissions=role.permissions.value, + position=role.position, + ) + for role in guild.roles + } + + guild_role_ids = {role.id for role in guild_roles} + api_role_ids = {role.id for role in db_roles} + new_role_ids = guild_role_ids - api_role_ids + deleted_role_ids = api_role_ids - guild_role_ids + + # New roles are those which are on the cached guild but not on the + # DB guild, going by the role ID. We need to send them in for creation. + roles_to_create = {role for role in guild_roles if role.id in new_role_ids} + roles_to_update = guild_roles - db_roles - roles_to_create + roles_to_delete = {role for role in db_roles if role.id in deleted_role_ids} + + return _Diff(roles_to_create, roles_to_update, roles_to_delete) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the role cache of `guild`.""" + log.trace("Syncing created roles...") + for role in diff.created: + await self.bot.api_client.post('bot/roles', json=role._asdict()) + + log.trace("Syncing updated roles...") + for role in diff.updated: + await self.bot.api_client.put(f'bot/roles/{role.id}', json=role._asdict()) + + log.trace("Syncing deleted roles...") + for role in diff.deleted: + await self.bot.api_client.delete(f'bot/roles/{role.id}') + + +class UserSyncer(Syncer): + """Synchronise the database with users in the cache.""" + + name = "user" + + async def _get_diff(self, guild: Guild) -> _Diff: + """Return the difference of users between the cache of `guild` and the database.""" + log.trace("Getting the diff for users.") + users = await self.bot.api_client.get('bot/users') + + # Pack DB roles and guild roles into one common, hashable format. + # They're hashable so that they're easily comparable with sets later. + db_users = { + user_dict['id']: _User( + roles=tuple(sorted(user_dict.pop('roles'))), + **user_dict + ) + for user_dict in users + } + guild_users = { + member.id: _User( + id=member.id, + name=member.name, + discriminator=int(member.discriminator), + avatar_hash=member.avatar, + roles=tuple(sorted(role.id for role in member.roles)), + in_guild=True + ) + for member in guild.members + } + + users_to_create = set() + users_to_update = set() + + for db_user in db_users.values(): + guild_user = guild_users.get(db_user.id) + if guild_user is not None: + if db_user != guild_user: + users_to_update.add(guild_user) + + elif db_user.in_guild: + # The user is known in the DB but not the guild, and the + # DB currently specifies that the user is a member of the guild. + # This means that the user has left since the last sync. + # Update the `in_guild` attribute of the user on the site + # to signify that the user left. + new_api_user = db_user._replace(in_guild=False) + users_to_update.add(new_api_user) + + new_user_ids = set(guild_users.keys()) - set(db_users.keys()) + for user_id in new_user_ids: + # The user is known on the guild but not on the API. This means + # that the user has joined since the last sync. Create it. + new_user = guild_users[user_id] + users_to_create.add(new_user) + + return _Diff(users_to_create, users_to_update, None) + + async def _sync(self, diff: _Diff) -> None: + """Synchronise the database with the user cache of `guild`.""" + log.trace("Syncing created users...") + for user in diff.created: + await self.bot.api_client.post('bot/users', json=user._asdict()) + + log.trace("Syncing updated users...") + for user in diff.updated: + await self.bot.api_client.put(f'bot/users/{user.id}', json=user._asdict()) diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 54a51921c..5da9a4148 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -15,8 +15,7 @@ from bot.pagination import LinePaginator log = logging.getLogger(__name__) TEST_CHANNELS = ( - Channels.devtest, - Channels.bot, + Channels.bot_commands, Channels.helpers ) @@ -116,8 +115,10 @@ class Tags(Cog): if _command_on_cooldown(tag_name): time_left = Cooldowns.tags - (time.time() - self.tag_cooldowns[tag_name]["time"]) - log.warning(f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " - f"Cooldown ends in {time_left:.1f} seconds.") + log.info( + f"{ctx.author} tried to get the '{tag_name}' tag, but the tag is on cooldown. " + f"Cooldown ends in {time_left:.1f} seconds." + ) return await self._get_tags() @@ -219,7 +220,7 @@ class Tags(Cog): )) @tags_group.command(name='delete', aliases=('remove', 'rm', 'd')) - @with_role(Roles.admin, Roles.owner) + @with_role(Roles.admins, Roles.owners) async def delete_command(self, ctx: Context, *, tag_name: TagNameConverter) -> None: """Remove a tag from the database.""" await self.bot.api_client.delete(f'bot/tags/{tag_name}') diff --git a/bot/cogs/utils.py b/bot/cogs/utils.py index da278011a..94b9d6b5a 100644 --- a/bot/cogs/utils.py +++ b/bot/cogs/utils.py @@ -89,7 +89,7 @@ class Utils(Cog): await ctx.message.channel.send(embed=pep_embed) @command() - @in_channel(Channels.bot, bypass_roles=STAFF_ROLES) + @in_channel(Channels.bot_commands, bypass_roles=STAFF_ROLES) async def charinfo(self, ctx: Context, *, characters: str) -> None: """Shows you information on up to 25 unicode characters.""" match = re.match(r"<(a?):(\w+):(\d+)>", characters) diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 988e0d49a..57b50c34f 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -1,7 +1,8 @@ import logging +from contextlib import suppress from datetime import datetime -from discord import Colour, Message, NotFound, Object +from discord import Colour, Forbidden, Message, NotFound, Object from discord.ext import tasks from discord.ext.commands import Cog, Context, command @@ -29,15 +30,16 @@ your information removed here as well. Feel free to review them at any point! Additionally, if you'd like to receive notifications for the announcements we post in <#{Channels.announcements}> \ -from time to time, you can send `!subscribe` to <#{Channels.bot}> at any time to assign yourself the \ +from time to time, you can send `!subscribe` to <#{Channels.bot_commands}> at any time to assign yourself the \ **Announcements** role. We'll mention this role every time we make an announcement. -If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to <#{Channels.bot}>. +If you'd like to unsubscribe from the announcement notifications, simply send `!unsubscribe` to \ +<#{Channels.bot_commands}>. """ PERIODIC_PING = ( f"@everyone To verify that you have read our rules, please type `{BotConfig.prefix}accept`." - f" If you encounter any problems during the verification process, ping the <@&{Roles.admin}> role in this channel." + f" If you encounter any problems during the verification process, ping the <@&{Roles.admins}> role in this channel." ) BOT_MESSAGE_DELETE_DELAY = 10 @@ -92,19 +94,21 @@ class Verification(Cog): ping_everyone=Filter.ping_everyone, ) - ctx = await self.bot.get_context(message) # type: Context - + ctx: Context = await self.bot.get_context(message) if ctx.command is not None and ctx.command.name == "accept": - return # They used the accept command + return - for role in ctx.author.roles: - if role.id == Roles.verified: - log.warning(f"{ctx.author} posted '{ctx.message.content}' " - "in the verification channel, but is already verified.") - return # They're already verified + if any(r.id == Roles.verified for r in ctx.author.roles): + log.info( + f"{ctx.author} posted '{ctx.message.content}' " + "in the verification channel, but is already verified." + ) + return - log.debug(f"{ctx.author} posted '{ctx.message.content}' in the verification " - "channel. We are providing instructions how to verify.") + log.debug( + f"{ctx.author} posted '{ctx.message.content}' in the verification " + "channel. We are providing instructions how to verify." + ) await ctx.send( f"{ctx.author.mention} Please type `!accept` to verify that you accept our rules, " f"and gain access to the rest of the server.", @@ -112,11 +116,8 @@ class Verification(Cog): ) log.trace(f"Deleting the message posted by {ctx.author}") - - try: + with suppress(NotFound): await ctx.message.delete() - except NotFound: - log.trace("No message found, it must have been deleted by another bot.") @command(name='accept', aliases=('verify', 'verified', 'accepted'), hidden=True) @without_role(Roles.verified) @@ -127,20 +128,16 @@ class Verification(Cog): await ctx.author.add_roles(Object(Roles.verified), reason="Accepted the rules") try: await ctx.author.send(WELCOME_MESSAGE) - except Exception: - # Catch the exception, in case they have DMs off or something - log.exception(f"Unable to send welcome message to user {ctx.author}.") - - log.trace(f"Deleting the message posted by {ctx.author}.") - - try: - self.mod_log.ignore(Event.message_delete, ctx.message.id) - await ctx.message.delete() - except NotFound: - log.trace("No message found, it must have been deleted by another bot.") + except Forbidden: + log.info(f"Sending welcome message failed for {ctx.author}.") + finally: + log.trace(f"Deleting accept message by {ctx.author}.") + with suppress(NotFound): + self.mod_log.ignore(Event.message_delete, ctx.message.id) + await ctx.message.delete() @command(name='subscribe') - @in_channel(Channels.bot) + @in_channel(Channels.bot_commands) async def subscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Subscribe to announcement notifications by assigning yourself the role.""" has_role = False @@ -164,7 +161,7 @@ class Verification(Cog): ) @command(name='unsubscribe') - @in_channel(Channels.bot) + @in_channel(Channels.bot_commands) async def unsubscribe_command(self, ctx: Context, *_) -> None: # We don't actually care about the args """Unsubscribe from announcement notifications by removing the role from yourself.""" has_role = False @@ -223,7 +220,7 @@ class Verification(Cog): @periodic_ping.before_loop async def before_ping(self) -> None: """Only start the loop when the bot is ready.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() def cog_unload(self) -> None: """Cancel the periodic ping task when the cog is unloaded.""" diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index eb787b083..3667a80e8 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -91,7 +91,7 @@ class WatchChannel(metaclass=CogABCMeta): async def start_watchchannel(self) -> None: """Starts the watch channel by getting the channel, webhook, and user cache ready.""" - await self.bot.wait_until_ready() + await self.bot.wait_until_guild_available() try: self.channel = await self.bot.fetch_channel(self.destination) diff --git a/bot/constants.py b/bot/constants.py index fe8e57322..14f8dc094 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -186,6 +186,11 @@ class YAMLGetter(type): def __getitem__(cls, name): return cls.__getattr__(name) + def __iter__(cls): + """Return generator of key: value pairs of current constants class' config values.""" + for name in cls.__annotations__: + yield name, getattr(cls, name) + # Dataclasses class Bot(metaclass=YAMLGetter): @@ -193,7 +198,7 @@ class Bot(metaclass=YAMLGetter): prefix: str token: str - + sentry_dsn: str class Filter(metaclass=YAMLGetter): section = "filter" @@ -263,6 +268,7 @@ class Emojis(metaclass=YAMLGetter): new: str pencil: str cross_mark: str + check_mark: str ducky_yellow: int ducky_blurple: int @@ -357,16 +363,16 @@ class Channels(metaclass=YAMLGetter): section = "guild" subsection = "channels" - admins: int admin_spam: int + admins: int announcements: int attachment_log: int big_brother_logs: int - bot: int - checkpoint_test: int + bot_commands: int defcon: int - devlog: int - devtest: int + dev_contrib: int + dev_core: int + dev_log: int esoteric: int help_0: int help_1: int @@ -379,19 +385,19 @@ class Channels(metaclass=YAMLGetter): helpers: int message_log: int meta: int + mod_alerts: int + mod_log: int mod_spam: int mods: int - mod_alerts: int - modlog: int off_topic_0: int off_topic_1: int off_topic_2: int organisation: int - python: int + python_discussion: int reddit: int talent_pool: int - userlog: int - user_event_a: int + user_event_announcements: int + user_log: int verification: int voice_log: int @@ -404,25 +410,25 @@ class Webhooks(metaclass=YAMLGetter): big_brother: int reddit: int duck_pond: int + dev_log: int class Roles(metaclass=YAMLGetter): section = "guild" subsection = "roles" - admin: int + admins: int announcements: int - champion: int - contributor: int - core_developer: int + contributors: int + core_developers: int helpers: int - jammer: int - moderator: int + jammers: int + moderators: int muted: int - owner: int + owners: int partners: int - rockstars: int - team_leader: int + python_community: int + team_leaders: int verified: int # This is the Developers role on PyDis, here named verified for readability reasons. @@ -430,9 +436,12 @@ class Guild(metaclass=YAMLGetter): section = "guild" id: int - ignored: List[int] + moderation_channels: List[int] + moderation_roles: List[int] + modlog_blacklist: List[int] + reminder_whitelist: List[int] staff_channels: List[int] - + staff_roles: List[int] class Keys(metaclass=YAMLGetter): section = "keys" @@ -537,6 +546,13 @@ class RedirectOutput(metaclass=YAMLGetter): delete_delay: int +class Sync(metaclass=YAMLGetter): + section = 'sync' + + confirm_timeout: int + max_diff: int + + class Event(Enum): """ Event names. This does not include every event (for example, raw @@ -571,14 +587,14 @@ BOT_DIR = os.path.dirname(__file__) PROJECT_ROOT = os.path.abspath(os.path.join(BOT_DIR, os.pardir)) # Default role combinations -MODERATION_ROLES = Roles.moderator, Roles.admin, Roles.owner -STAFF_ROLES = Roles.helpers, Roles.moderator, Roles.admin, Roles.owner +MODERATION_ROLES = Guild.moderation_roles +STAFF_ROLES = Guild.staff_roles # Roles combinations STAFF_CHANNELS = Guild.staff_channels # Default Channel combinations -MODERATION_CHANNELS = Channels.admins, Channels.admin_spam, Channels.mod_alerts, Channels.mods, Channels.mod_spam +MODERATION_CHANNELS = Guild.moderation_channels # Bot replies diff --git a/bot/pagination.py b/bot/pagination.py index e82763912..90c8f849c 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -1,8 +1,9 @@ import asyncio import logging -from typing import Iterable, List, Optional, Tuple +import typing as t +from contextlib import suppress -from discord import Embed, Member, Message, Reaction +import discord from discord.abc import User from discord.ext.commands import Context, Paginator @@ -14,7 +15,7 @@ RIGHT_EMOJI = "\u27A1" # [:arrow_right:] LAST_EMOJI = "\u23ED" # [:track_next:] DELETE_EMOJI = constants.Emojis.trashcan # [:trashcan:] -PAGINATION_EMOJI = [FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI] +PAGINATION_EMOJI = (FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI) log = logging.getLogger(__name__) @@ -89,12 +90,12 @@ class LinePaginator(Paginator): @classmethod async def paginate( cls, - lines: Iterable[str], + lines: t.List[str], ctx: Context, - embed: Embed, + embed: discord.Embed, prefix: str = "", suffix: str = "", - max_lines: Optional[int] = None, + max_lines: t.Optional[int] = None, max_size: int = 500, empty: bool = True, restrict_to_user: User = None, @@ -102,7 +103,7 @@ class LinePaginator(Paginator): footer_text: str = None, url: str = None, exception_on_empty_embed: bool = False - ) -> Optional[Message]: + ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of lines. @@ -114,11 +115,11 @@ class LinePaginator(Paginator): Pagination will also be removed automatically if no reaction is added for five minutes (300 seconds). Example: - >>> embed = Embed() + >>> embed = discord.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) - >>> await LinePaginator.paginate((line for line in lines), ctx, embed) + >>> await LinePaginator.paginate([line for line in lines], ctx, embed) """ - def event_check(reaction_: Reaction, user_: Member) -> bool: + def event_check(reaction_: discord.Reaction, user_: discord.Member) -> bool: """Make sure that this reaction is what we want to operate on.""" no_restrictions = ( # Pagination is not restricted @@ -281,8 +282,9 @@ class LinePaginator(Paginator): await message.edit(embed=embed) - log.debug("Ending pagination and removing all reactions...") - await message.clear_reactions() + log.debug("Ending pagination and clearing reactions.") + with suppress(discord.NotFound): + await message.clear_reactions() class ImagePaginator(Paginator): @@ -299,6 +301,7 @@ class ImagePaginator(Paginator): self._current_page = [prefix] self.images = [] self._pages = [] + self._count = 0 def add_line(self, line: str = '', *, empty: bool = False) -> None: """Adds a line to each page.""" @@ -316,13 +319,13 @@ class ImagePaginator(Paginator): @classmethod async def paginate( cls, - pages: List[Tuple[str, str]], - ctx: Context, embed: Embed, + pages: t.List[t.Tuple[str, str]], + ctx: Context, embed: discord.Embed, prefix: str = "", suffix: str = "", timeout: int = 300, exception_on_empty_embed: bool = False - ) -> Optional[Message]: + ) -> t.Optional[discord.Message]: """ Use a paginator and set of reactions to provide pagination over a set of title/image pairs. @@ -334,11 +337,11 @@ class ImagePaginator(Paginator): Note: Pagination will be removed automatically if no reaction is added for five minutes (300 seconds). Example: - >>> embed = Embed() + >>> embed = discord.Embed() >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) >>> await ImagePaginator.paginate(pages, ctx, embed) """ - def check_event(reaction_: Reaction, member: Member) -> bool: + def check_event(reaction_: discord.Reaction, member: discord.Member) -> bool: """Checks each reaction added, if it matches our conditions pass the wait_for.""" return all(( # Reaction is on the same message sent @@ -445,5 +448,6 @@ class ImagePaginator(Paginator): await message.edit(embed=embed) - log.debug("Ending pagination and removing all reactions...") - await message.clear_reactions() + log.debug("Ending pagination and clearing reactions.") + with suppress(discord.NotFound): + await message.clear_reactions() diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 00bb2a949..8903c385c 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -19,7 +19,7 @@ async def apply( if total_recent_attachments > config['max']: return ( - f"sent {total_recent_attachments} attachments in {config['max']}s", + f"sent {total_recent_attachments} attachments in {config['interval']}s", (last_message.author,), relevant_messages ) diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 8184be824..3e4b15ce4 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Any, Generator, Hashable, Iterable +from typing import Any, Hashable from discord.ext.commands import CogMeta @@ -64,13 +64,3 @@ class CaseInsensitiveDict(dict): for k in list(self.keys()): v = super(CaseInsensitiveDict, self).pop(k) self.__setitem__(k, v) - - -def chunks(iterable: Iterable, size: int) -> Generator[Any, None, None]: - """ - Generator that allows you to iterate over any indexable collection in `size`-length chunks. - - Found: https://stackoverflow.com/a/312464/4022104 - """ - for i in range(0, len(iterable), size): - yield iterable[i:i + size] diff --git a/bot/utils/time.py b/bot/utils/time.py index 7416f36e0..77060143c 100644 --- a/bot/utils/time.py +++ b/bot/utils/time.py @@ -114,30 +114,40 @@ def format_infraction(timestamp: str) -> str: def format_infraction_with_duration( - expiry: Optional[str], + date_to: Optional[str], date_from: Optional[datetime.datetime] = None, - max_units: int = 2 + max_units: int = 2, + absolute: bool = True ) -> Optional[str]: """ - Format an infraction timestamp to a more readable ISO 8601 format WITH the duration. + Return `date_to` formatted as a readable ISO-8601 with the humanized duration since `date_from`. - Returns a human-readable version of the duration between datetime.utcnow() and an expiry. - Unlike `humanize_delta`, this function will force the `precision` to be `seconds` by not passing it. - `max_units` specifies the maximum number of units of time to include (e.g. 1 may include days but not hours). - By default, max_units is 2. + `date_from` must be an ISO-8601 formatted timestamp. The duration is calculated as from + `date_from` until `date_to` with a precision of seconds. If `date_from` is unspecified, the + current time is used. + + `max_units` specifies the maximum number of units of time to include in the duration. For + example, a value of 1 may include days but not hours. + + If `absolute` is True, the absolute value of the duration delta is used. This prevents negative + values in the case that `date_to` is in the past relative to `date_from`. """ - if not expiry: + if not date_to: return None + date_to_formatted = format_infraction(date_to) + date_from = date_from or datetime.datetime.utcnow() - date_to = dateutil.parser.isoparse(expiry).replace(tzinfo=None, microsecond=0) + date_to = dateutil.parser.isoparse(date_to).replace(tzinfo=None, microsecond=0) - expiry_formatted = format_infraction(expiry) + delta = relativedelta(date_to, date_from) + if absolute: + delta = abs(delta) - duration = humanize_delta(relativedelta(date_to, date_from), max_units=max_units) - duration_formatted = f" ({duration})" if duration else '' + duration = humanize_delta(delta, max_units=max_units) + duration_formatted = f" ({duration})" if duration else "" - return f"{expiry_formatted}{duration_formatted}" + return f"{date_to_formatted}{duration_formatted}" def until_expiration( diff --git a/config-default.yml b/config-default.yml index 3345e6f2a..ab237423f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,6 +1,7 @@ bot: prefix: "!" token: !ENV "BOT_TOKEN" + sentry_dsn: !ENV "BOT_SENTRY_DSN" cooldowns: # Per channel, per tag. @@ -34,6 +35,7 @@ style: pencil: "\u270F" new: "\U0001F195" cross_mark: "\u274C" + check_mark: "\u2705" ducky_yellow: &DUCKY_YELLOW 574951975574175744 ducky_blurple: &DUCKY_BLURPLE 574951975310065675 @@ -109,74 +111,135 @@ guild: id: 267624335836053506 categories: - python_help: 356013061213126657 + python_help: 356013061213126657 channels: - admins: &ADMINS 365960823622991872 - admin_spam: &ADMIN_SPAM 563594791770914816 - admins_voice: &ADMINS_VOICE 500734494840717332 - announcements: 354619224620138496 - attachment_log: &ATTCH_LOG 649243850006855680 - big_brother_logs: &BBLOGS 468507907357409333 - bot: 267659945086812160 - checkpoint_test: 422077681434099723 - defcon: &DEFCON 464469101889454091 - devlog: &DEVLOG 622895325144940554 - devtest: &DEVTEST 414574275865870337 - esoteric: 470884583684964352 - help_0: 303906576991780866 - help_1: 303906556754395136 - help_2: 303906514266226689 - help_3: 439702951246692352 - help_4: 451312046647148554 - help_5: 454941769734422538 - help_6: 587375753306570782 - help_7: 587375768556797982 - helpers: &HELPERS 385474242440986624 - message_log: &MESSAGE_LOG 467752170159079424 - meta: 429409067623251969 - mod_spam: &MOD_SPAM 620607373828030464 - mods: &MODS 305126844661760000 - mod_alerts: 473092532147060736 - modlog: &MODLOG 282638479504965634 - off_topic_0: 291284109232308226 - off_topic_1: 463035241142026251 - off_topic_2: 463035268514185226 - organisation: &ORGANISATION 551789653284356126 - python: 267624335836053506 - reddit: 458224812528238616 - staff_lounge: &STAFF_LOUNGE 464905259261755392 - staff_voice: &STAFF_VOICE 412375055910043655 - talent_pool: &TALENT_POOL 534321732593647616 - userlog: 528976905546760203 - user_event_a: &USER_EVENT_A 592000283102674944 - verification: 352442727016693763 - voice_log: 640292421988646961 - - staff_channels: [*ADMINS, *ADMIN_SPAM, *MOD_SPAM, *MODS, *HELPERS, *ORGANISATION, *DEFCON] - ignored: [*ADMINS, *MESSAGE_LOG, *MODLOG, *ADMINS_VOICE, *STAFF_VOICE, *ATTCH_LOG] + announcements: 354619224620138496 + user_event_announcements: &USER_EVENT_A 592000283102674944 + + # Development + dev_contrib: &DEV_CONTRIB 635950537262759947 + dev_core: &DEV_CORE 411200599653351425 + dev_log: &DEV_LOG 622895325144940554 + + # Discussion + meta: 429409067623251969 + python_discussion: 267624335836053506 + + # Logs + attachment_log: &ATTACH_LOG 649243850006855680 + message_log: &MESSAGE_LOG 467752170159079424 + mod_log: &MOD_LOG 282638479504965634 + user_log: 528976905546760203 + voice_log: 640292421988646961 + + # Off-topic + off_topic_0: 291284109232308226 + off_topic_1: 463035241142026251 + off_topic_2: 463035268514185226 + + # Python Help + help_0: 303906576991780866 + help_1: 303906556754395136 + help_2: 303906514266226689 + help_3: 439702951246692352 + help_4: 451312046647148554 + help_5: 454941769734422538 + help_6: 587375753306570782 + help_7: 587375768556797982 + + # Special + bot_commands: &BOT_CMD 267659945086812160 + esoteric: 470884583684964352 + reddit: 458224812528238616 + verification: 352442727016693763 + + # Staff + admins: &ADMINS 365960823622991872 + admin_spam: &ADMIN_SPAM 563594791770914816 + defcon: &DEFCON 464469101889454091 + helpers: &HELPERS 385474242440986624 + mods: &MODS 305126844661760000 + mod_alerts: &MOD_ALERTS 473092532147060736 + mod_spam: &MOD_SPAM 620607373828030464 + organisation: &ORGANISATION 551789653284356126 + staff_lounge: &STAFF_LOUNGE 464905259261755392 + + # Voice + admins_voice: &ADMINS_VOICE 500734494840717332 + staff_voice: &STAFF_VOICE 412375055910043655 + + # Watch + big_brother_logs: &BB_LOGS 468507907357409333 + talent_pool: &TALENT_POOL 534321732593647616 + + staff_channels: + - *ADMINS + - *ADMIN_SPAM + - *DEFCON + - *HELPERS + - *MODS + - *MOD_SPAM + - *ORGANISATION + + moderation_channels: + - *ADMINS + - *ADMIN_SPAM + - *MOD_ALERTS + - *MODS + - *MOD_SPAM + + # Modlog cog ignores events which occur in these channels + modlog_blacklist: + - *ADMINS + - *ADMINS_VOICE + - *ATTACH_LOG + - *MESSAGE_LOG + - *MOD_LOG + - *STAFF_VOICE + + reminder_whitelist: + - *BOT_CMD + - *DEV_CONTRIB roles: - admin: &ADMIN_ROLE 267628507062992896 - announcements: 463658397560995840 - champion: 430492892331769857 - contributor: 295488872404484098 - core_developer: 587606783669829632 - helpers: 267630620367257601 - jammer: 591786436651646989 - moderator: &MOD_ROLE 267629731250176001 - muted: &MUTED_ROLE 277914926603829249 - owner: &OWNER_ROLE 267627879762755584 - partners: 323426753857191936 - rockstars: &ROCKSTARS_ROLE 458226413825294336 - team_leader: 501324292341104650 - verified: 352427296948486144 + announcements: 463658397560995840 + contributors: 295488872404484098 + muted: &MUTED_ROLE 277914926603829249 + partners: 323426753857191936 + python_community: &PY_COMMUNITY_ROLE 458226413825294336 + + # This is the Developers role on PyDis, here named verified for readability reasons + verified: 352427296948486144 + + # Staff + admins: &ADMINS_ROLE 267628507062992896 + core_developers: 587606783669829632 + helpers: &HELPERS_ROLE 267630620367257601 + moderators: &MODS_ROLE 267629731250176001 + owners: &OWNERS_ROLE 267627879762755584 + + # Code Jam + jammers: 591786436651646989 + team_leaders: 501324292341104650 + + moderation_roles: + - *OWNERS_ROLE + - *ADMINS_ROLE + - *MODS_ROLE + + staff_roles: + - *OWNERS_ROLE + - *ADMINS_ROLE + - *MODS_ROLE + - *HELPERS_ROLE webhooks: - talent_pool: 569145364800602132 - big_brother: 569133704568373283 - reddit: 635408384794951680 - duck_pond: 637821475327311927 + talent_pool: 569145364800602132 + big_brother: 569133704568373283 + reddit: 635408384794951680 + duck_pond: 637821475327311927 + dev_log: 680501655111729222 filter: @@ -216,6 +279,7 @@ filter: - 438622377094414346 # Pyglet - 524691714909274162 # Panda3D - 336642139381301249 # discord.py + - 405403391410438165 # Sentdex domain_blacklist: - pornhub.com @@ -253,20 +317,19 @@ filter: # Censor doesn't apply to these channel_whitelist: - *ADMINS - - *MODLOG + - *MOD_LOG - *MESSAGE_LOG - - *DEVLOG - - *BBLOGS + - *DEV_LOG + - *BB_LOGS - *STAFF_LOUNGE - - *DEVTEST - *TALENT_POOL - *USER_EVENT_A role_whitelist: - - *ADMIN_ROLE - - *MOD_ROLE - - *OWNER_ROLE - - *ROCKSTARS_ROLE + - *ADMINS_ROLE + - *MODS_ROLE + - *OWNERS_ROLE + - *PY_COMMUNITY_ROLE keys: @@ -428,9 +491,26 @@ redirect_output: delete_invocation: true delete_delay: 15 +sync: + confirm_timeout: 300 + max_diff: 10 + duck_pond: threshold: 5 - custom_emojis: [*DUCKY_YELLOW, *DUCKY_BLURPLE, *DUCKY_CAMO, *DUCKY_DEVIL, *DUCKY_NINJA, *DUCKY_REGAL, *DUCKY_TUBE, *DUCKY_HUNT, *DUCKY_WIZARD, *DUCKY_PARTY, *DUCKY_ANGEL, *DUCKY_MAUL, *DUCKY_SANTA] + custom_emojis: + - *DUCKY_YELLOW + - *DUCKY_BLURPLE + - *DUCKY_CAMO + - *DUCKY_DEVIL + - *DUCKY_NINJA + - *DUCKY_REGAL + - *DUCKY_TUBE + - *DUCKY_HUNT + - *DUCKY_WIZARD + - *DUCKY_PARTY + - *DUCKY_ANGEL + - *DUCKY_MAUL + - *DUCKY_SANTA config: required_keys: ['bot.token'] diff --git a/docker-compose.yml b/docker-compose.yml index 7281c7953..11deceae8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,6 +23,7 @@ services: - staff.web ports: - "127.0.0.1:8000:8000" + tty: true depends_on: - postgres environment: @@ -37,6 +38,7 @@ services: volumes: - ./logs:/bot/logs - .:/bot:ro + tty: true depends_on: - web environment: diff --git a/tests/base.py b/tests/base.py index 029a249ed..88693f382 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,12 @@ import logging import unittest from contextlib import contextmanager +from typing import Dict + +import discord +from discord.ext import commands + +from tests import helpers class _CaptureLogHandler(logging.Handler): @@ -65,3 +71,31 @@ class LoggingTestCase(unittest.TestCase): standard_message = self._truncateMessage(base_message, record_message) msg = self._formatMessage(msg, standard_message) self.fail(msg) + + +class CommandTestCase(unittest.TestCase): + """TestCase with additional assertions that are useful for testing Discord commands.""" + + @helpers.async_test + async def assertHasPermissionsCheck( + self, + cmd: commands.Command, + permissions: Dict[str, bool], + ) -> None: + """ + Test that `cmd` raises a `MissingPermissions` exception if author lacks `permissions`. + + Every permission in `permissions` is expected to be reported as missing. In other words, do + not include permissions which should not raise an exception along with those which should. + """ + # Invert permission values because it's more intuitive to pass to this assertion the same + # permissions as those given to the check decorator. + permissions = {k: not v for k, v in permissions.items()} + + ctx = helpers.MockContext() + ctx.channel.permissions_for.return_value = discord.Permissions(**permissions) + + with self.assertRaises(commands.MissingPermissions) as cm: + await cmd.can_run(ctx) + + self.assertCountEqual(permissions.keys(), cm.exception.missing_perms) diff --git a/tests/bot/cogs/sync/test_base.py b/tests/bot/cogs/sync/test_base.py new file mode 100644 index 000000000..c2e143865 --- /dev/null +++ b/tests/bot/cogs/sync/test_base.py @@ -0,0 +1,412 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs.sync.syncers import Syncer, _Diff +from tests import helpers + + +class TestSyncer(Syncer): + """Syncer subclass with mocks for abstract methods for testing purposes.""" + + name = "test" + _get_diff = helpers.AsyncMock() + _sync = helpers.AsyncMock() + + +class SyncerBaseTests(unittest.TestCase): + """Tests for the syncer base class.""" + + def setUp(self): + self.bot = helpers.MockBot() + + def test_instantiation_fails_without_abstract_methods(self): + """The class must have abstract methods implemented.""" + with self.assertRaisesRegex(TypeError, "Can't instantiate abstract class"): + Syncer(self.bot) + + +class SyncerSendPromptTests(unittest.TestCase): + """Tests for sending the sync confirmation prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + + def mock_get_channel(self): + """Fixture to return a mock channel and message for when `get_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + mock_channel.send.return_value = mock_message + self.bot.get_channel.return_value = mock_channel + + return mock_channel, mock_message + + def mock_fetch_channel(self): + """Fixture to return a mock channel and message for when `fetch_channel` is used.""" + self.bot.reset_mock() + + mock_channel = helpers.MockTextChannel() + mock_message = helpers.MockMessage() + + self.bot.get_channel.return_value = None + mock_channel.send.return_value = mock_message + self.bot.fetch_channel.return_value = mock_channel + + return mock_channel, mock_message + + @helpers.async_test + async def test_send_prompt_edits_and_returns_message(self): + """The given message should be edited to display the prompt and then should be returned.""" + msg = helpers.MockMessage() + ret_val = await self.syncer._send_prompt(msg) + + msg.edit.assert_called_once() + self.assertIn("content", msg.edit.call_args[1]) + self.assertEqual(ret_val, msg) + + @helpers.async_test + async def test_send_prompt_gets_dev_core_channel(self): + """The dev-core channel should be retrieved if an extant message isn't given.""" + subtests = ( + (self.bot.get_channel, self.mock_get_channel), + (self.bot.fetch_channel, self.mock_fetch_channel), + ) + + for method, mock_ in subtests: + with self.subTest(method=method, msg=mock_.__name__): + mock_() + await self.syncer._send_prompt() + + method.assert_called_once_with(constants.Channels.dev_core) + + @helpers.async_test + async def test_send_prompt_returns_None_if_channel_fetch_fails(self): + """None should be returned if there's an HTTPException when fetching the channel.""" + self.bot.get_channel.return_value = None + self.bot.fetch_channel.side_effect = discord.HTTPException(mock.MagicMock(), "test error!") + + ret_val = await self.syncer._send_prompt() + + self.assertIsNone(ret_val) + + @helpers.async_test + async def test_send_prompt_sends_and_returns_new_message_if_not_given(self): + """A new message mentioning core devs should be sent and returned if message isn't given.""" + for mock_ in (self.mock_get_channel, self.mock_fetch_channel): + with self.subTest(msg=mock_.__name__): + mock_channel, mock_message = mock_() + ret_val = await self.syncer._send_prompt() + + mock_channel.send.assert_called_once() + self.assertIn(self.syncer._CORE_DEV_MENTION, mock_channel.send.call_args[0][0]) + self.assertEqual(ret_val, mock_message) + + @helpers.async_test + async def test_send_prompt_adds_reactions(self): + """The message should have reactions for confirmation added.""" + extant_message = helpers.MockMessage() + subtests = ( + (extant_message, lambda: (None, extant_message)), + (None, self.mock_get_channel), + (None, self.mock_fetch_channel), + ) + + for message_arg, mock_ in subtests: + subtest_msg = "Extant message" if mock_.__name__ == "<lambda>" else mock_.__name__ + + with self.subTest(msg=subtest_msg): + _, mock_message = mock_() + await self.syncer._send_prompt(message_arg) + + calls = [mock.call(emoji) for emoji in self.syncer._REACTION_EMOJIS] + mock_message.add_reaction.assert_has_calls(calls) + + +class SyncerConfirmationTests(unittest.TestCase): + """Tests for waiting for a sync confirmation reaction on the prompt.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = TestSyncer(self.bot) + self.core_dev_role = helpers.MockRole(id=constants.Roles.core_developers) + + @staticmethod + def get_message_reaction(emoji): + """Fixture to return a mock message an reaction from the given `emoji`.""" + message = helpers.MockMessage() + reaction = helpers.MockReaction(emoji=emoji, message=message) + + return message, reaction + + def test_reaction_check_for_valid_emoji_and_authors(self): + """Should return True if authors are identical or are a bot and a core dev, respectively.""" + user_subtests = ( + ( + helpers.MockMember(id=77), + helpers.MockMember(id=77), + "identical users", + ), + ( + helpers.MockMember(id=77, bot=True), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "bot author and core-dev reactor", + ), + ) + + for emoji in self.syncer._REACTION_EMOJIS: + for author, user, msg in user_subtests: + with self.subTest(author=author, user=user, emoji=emoji, msg=msg): + message, reaction = self.get_message_reaction(emoji) + ret_val = self.syncer._reaction_check(author, message, reaction, user) + + self.assertTrue(ret_val) + + def test_reaction_check_for_invalid_reactions(self): + """Should return False for invalid reaction events.""" + valid_emoji = self.syncer._REACTION_EMOJIS[0] + subtests = ( + ( + helpers.MockMember(id=77), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43, roles=[self.core_dev_role]), + "users are not identical", + ), + ( + helpers.MockMember(id=77, bot=True), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=43), + "reactor lacks the core-dev role", + ), + ( + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + *self.get_message_reaction(valid_emoji), + helpers.MockMember(id=77, bot=True, roles=[self.core_dev_role]), + "reactor is a bot", + ), + ( + helpers.MockMember(id=77), + helpers.MockMessage(id=95), + helpers.MockReaction(emoji=valid_emoji, message=helpers.MockMessage(id=26)), + helpers.MockMember(id=77), + "messages are not identical", + ), + ( + helpers.MockMember(id=77), + *self.get_message_reaction("InVaLiD"), + helpers.MockMember(id=77), + "emoji is invalid", + ), + ) + + for *args, msg in subtests: + kwargs = dict(zip(("author", "message", "reaction", "user"), args)) + with self.subTest(**kwargs, msg=msg): + ret_val = self.syncer._reaction_check(*args) + self.assertFalse(ret_val) + + @helpers.async_test + async def test_wait_for_confirmation(self): + """The message should always be edited and only return True if the emoji is a check mark.""" + subtests = ( + (constants.Emojis.check_mark, True, None), + ("InVaLiD", False, None), + (None, False, TimeoutError), + ) + + for emoji, ret_val, side_effect in subtests: + for bot in (True, False): + with self.subTest(emoji=emoji, ret_val=ret_val, side_effect=side_effect, bot=bot): + # Set up mocks + message = helpers.MockMessage() + member = helpers.MockMember(bot=bot) + + self.bot.wait_for.reset_mock() + self.bot.wait_for.return_value = (helpers.MockReaction(emoji=emoji), None) + self.bot.wait_for.side_effect = side_effect + + # Call the function + actual_return = await self.syncer._wait_for_confirmation(member, message) + + # Perform assertions + self.bot.wait_for.assert_called_once() + self.assertIn("reaction_add", self.bot.wait_for.call_args[0]) + + message.edit.assert_called_once() + kwargs = message.edit.call_args[1] + self.assertIn("content", kwargs) + + # Core devs should only be mentioned if the author is a bot. + if bot: + self.assertIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + else: + self.assertNotIn(self.syncer._CORE_DEV_MENTION, kwargs["content"]) + + self.assertIs(actual_return, ret_val) + + +class SyncerSyncTests(unittest.TestCase): + """Tests for main function orchestrating the sync.""" + + def setUp(self): + self.bot = helpers.MockBot(user=helpers.MockMember(bot=True)) + self.syncer = TestSyncer(self.bot) + + @helpers.async_test + async def test_sync_respects_confirmation_result(self): + """The sync should abort if confirmation fails and continue if confirmed.""" + mock_message = helpers.MockMessage() + subtests = ( + (True, mock_message), + (False, None), + ) + + for confirmed, message in subtests: + with self.subTest(confirmed=confirmed): + self.syncer._sync.reset_mock() + self.syncer._get_diff.reset_mock() + + diff = _Diff({1, 2, 3}, {4, 5}, None) + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = helpers.AsyncMock( + return_value=(confirmed, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + + if confirmed: + self.syncer._sync.assert_called_once_with(diff) + else: + self.syncer._sync.assert_not_called() + + @helpers.async_test + async def test_sync_diff_size(self): + """The diff size should be correctly calculated.""" + subtests = ( + (6, _Diff({1, 2}, {3, 4}, {5, 6})), + (5, _Diff({1, 2, 3}, None, {4, 5})), + (0, _Diff(None, None, None)), + (0, _Diff(set(), set(), set())), + ) + + for size, diff in subtests: + with self.subTest(size=size, diff=diff): + self.syncer._get_diff.reset_mock() + self.syncer._get_diff.return_value = diff + self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + self.syncer._get_diff.assert_called_once_with(guild) + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][0], size) + + @helpers.async_test + async def test_sync_message_edited(self): + """The message should be edited if one was sent, even if the sync has an API error.""" + subtests = ( + (None, None, False), + (helpers.MockMessage(), None, True), + (helpers.MockMessage(), ResponseCodeError(mock.MagicMock()), True), + ) + + for message, side_effect, should_edit in subtests: + with self.subTest(message=message, side_effect=side_effect, should_edit=should_edit): + self.syncer._sync.side_effect = side_effect + self.syncer._get_confirmation_result = helpers.AsyncMock( + return_value=(True, message) + ) + + guild = helpers.MockGuild() + await self.syncer.sync(guild) + + if should_edit: + message.edit.assert_called_once() + self.assertIn("content", message.edit.call_args[1]) + + @helpers.async_test + async def test_sync_confirmation_context_redirect(self): + """If ctx is given, a new message should be sent and author should be ctx's author.""" + mock_member = helpers.MockMember() + subtests = ( + (None, self.bot.user, None), + (helpers.MockContext(author=mock_member), mock_member, helpers.MockMessage()), + ) + + for ctx, author, message in subtests: + with self.subTest(ctx=ctx, author=author, message=message): + if ctx is not None: + ctx.send.return_value = message + + self.syncer._get_confirmation_result = helpers.AsyncMock(return_value=(False, None)) + + guild = helpers.MockGuild() + await self.syncer.sync(guild, ctx) + + if ctx is not None: + ctx.send.assert_called_once() + + self.syncer._get_confirmation_result.assert_called_once() + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][1], author) + self.assertEqual(self.syncer._get_confirmation_result.call_args[0][2], message) + + @mock.patch.object(constants.Sync, "max_diff", new=3) + @helpers.async_test + async def test_confirmation_result_small_diff(self): + """Should always return True and the given message if the diff size is too small.""" + author = helpers.MockMember() + expected_message = helpers.MockMessage() + + for size in (3, 2): + with self.subTest(size=size): + self.syncer._send_prompt = helpers.AsyncMock() + self.syncer._wait_for_confirmation = helpers.AsyncMock() + + coro = self.syncer._get_confirmation_result(size, author, expected_message) + result, actual_message = await coro + + self.assertTrue(result) + self.assertEqual(actual_message, expected_message) + self.syncer._send_prompt.assert_not_called() + self.syncer._wait_for_confirmation.assert_not_called() + + @mock.patch.object(constants.Sync, "max_diff", new=3) + @helpers.async_test + async def test_confirmation_result_large_diff(self): + """Should return True if confirmed and False if _send_prompt fails or aborted.""" + author = helpers.MockMember() + mock_message = helpers.MockMessage() + + subtests = ( + (True, mock_message, True, "confirmed"), + (False, None, False, "_send_prompt failed"), + (False, mock_message, False, "aborted"), + ) + + for expected_result, expected_message, confirmed, msg in subtests: + with self.subTest(msg=msg): + self.syncer._send_prompt = helpers.AsyncMock(return_value=expected_message) + self.syncer._wait_for_confirmation = helpers.AsyncMock(return_value=confirmed) + + coro = self.syncer._get_confirmation_result(4, author) + actual_result, actual_message = await coro + + self.syncer._send_prompt.assert_called_once_with(None) # message defaults to None + self.assertIs(actual_result, expected_result) + self.assertEqual(actual_message, expected_message) + + if expected_message: + self.syncer._wait_for_confirmation.assert_called_once_with( + author, expected_message + ) diff --git a/tests/bot/cogs/sync/test_cog.py b/tests/bot/cogs/sync/test_cog.py new file mode 100644 index 000000000..98c9afc0d --- /dev/null +++ b/tests/bot/cogs/sync/test_cog.py @@ -0,0 +1,395 @@ +import unittest +from unittest import mock + +import discord + +from bot import constants +from bot.api import ResponseCodeError +from bot.cogs import sync +from bot.cogs.sync.syncers import Syncer +from tests import helpers +from tests.base import CommandTestCase + + +class MockSyncer(helpers.CustomMockMixin, mock.MagicMock): + """ + A MagicMock subclass to mock Syncer objects. + + Instances of this class will follow the specifications of `bot.cogs.sync.syncers.Syncer` + instances. For more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=Syncer, **kwargs) + + +class SyncExtensionTests(unittest.TestCase): + """Tests for the sync extension.""" + + @staticmethod + def test_extension_setup(): + """The Sync cog should be added.""" + bot = helpers.MockBot() + sync.setup(bot) + bot.add_cog.assert_called_once() + + +class SyncCogTestCase(unittest.TestCase): + """Base class for Sync cog tests. Sets up patches for syncers.""" + + def setUp(self): + self.bot = helpers.MockBot() + + # These patch the type. When the type is called, a MockSyncer instanced is returned. + # MockSyncer is needed so that our custom AsyncMock is used. + # TODO: Use autospec instead in 3.8, which will automatically use AsyncMock when needed. + self.role_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.RoleSyncer", + new=mock.MagicMock(return_value=MockSyncer()) + ) + self.user_syncer_patcher = mock.patch( + "bot.cogs.sync.syncers.UserSyncer", + new=mock.MagicMock(return_value=MockSyncer()) + ) + self.RoleSyncer = self.role_syncer_patcher.start() + self.UserSyncer = self.user_syncer_patcher.start() + + self.cog = sync.Sync(self.bot) + + def tearDown(self): + self.role_syncer_patcher.stop() + self.user_syncer_patcher.stop() + + @staticmethod + def response_error(status: int) -> ResponseCodeError: + """Fixture to return a ResponseCodeError with the given status code.""" + response = mock.MagicMock() + response.status = status + + return ResponseCodeError(response) + + +class SyncCogTests(SyncCogTestCase): + """Tests for the Sync cog.""" + + @mock.patch.object(sync.Sync, "sync_guild") + def test_sync_cog_init(self, sync_guild): + """Should instantiate syncers and run a sync for the guild.""" + # Reset because a Sync cog was already instantiated in setUp. + self.RoleSyncer.reset_mock() + self.UserSyncer.reset_mock() + self.bot.loop.create_task.reset_mock() + + mock_sync_guild_coro = mock.MagicMock() + sync_guild.return_value = mock_sync_guild_coro + + sync.Sync(self.bot) + + self.RoleSyncer.assert_called_once_with(self.bot) + self.UserSyncer.assert_called_once_with(self.bot) + sync_guild.assert_called_once_with() + self.bot.loop.create_task.assert_called_once_with(mock_sync_guild_coro) + + @helpers.async_test + async def test_sync_cog_sync_guild(self): + """Roles and users should be synced only if a guild is successfully retrieved.""" + for guild in (helpers.MockGuild(), None): + with self.subTest(guild=guild): + self.bot.reset_mock() + self.cog.role_syncer.reset_mock() + self.cog.user_syncer.reset_mock() + + self.bot.get_guild = mock.MagicMock(return_value=guild) + + await self.cog.sync_guild() + + self.bot.wait_until_guild_available.assert_called_once() + self.bot.get_guild.assert_called_once_with(constants.Guild.id) + + if guild is None: + self.cog.role_syncer.sync.assert_not_called() + self.cog.user_syncer.sync.assert_not_called() + else: + self.cog.role_syncer.sync.assert_called_once_with(guild) + self.cog.user_syncer.sync.assert_called_once_with(guild) + + async def patch_user_helper(self, side_effect: BaseException) -> None: + """Helper to set a side effect for bot.api_client.patch and then assert it is called.""" + self.bot.api_client.patch.reset_mock(side_effect=True) + self.bot.api_client.patch.side_effect = side_effect + + user_id, updated_information = 5, {"key": 123} + await self.cog.patch_user(user_id, updated_information) + + self.bot.api_client.patch.assert_called_once_with( + f"bot/users/{user_id}", + json=updated_information, + ) + + @helpers.async_test + async def test_sync_cog_patch_user(self): + """A PATCH request should be sent and 404 errors ignored.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + await self.patch_user_helper(side_effect) + + @helpers.async_test + async def test_sync_cog_patch_user_non_404(self): + """A PATCH request should be sent and the error raised if it's not a 404.""" + with self.assertRaises(ResponseCodeError): + await self.patch_user_helper(self.response_error(500)) + + +class SyncCogListenerTests(SyncCogTestCase): + """Tests for the listeners of the Sync cog.""" + + def setUp(self): + super().setUp() + self.cog.patch_user = helpers.AsyncMock(spec_set=self.cog.patch_user) + + @helpers.async_test + async def test_sync_cog_on_guild_role_create(self): + """A POST request should be sent with the new role's data.""" + self.assertTrue(self.cog.on_guild_role_create.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + role = helpers.MockRole(**role_data) + await self.cog.on_guild_role_create(role) + + self.bot.api_client.post.assert_called_once_with("bot/roles", json=role_data) + + @helpers.async_test + async def test_sync_cog_on_guild_role_delete(self): + """A DELETE request should be sent.""" + self.assertTrue(self.cog.on_guild_role_delete.__cog_listener__) + + role = helpers.MockRole(id=99) + await self.cog.on_guild_role_delete(role) + + self.bot.api_client.delete.assert_called_once_with("bot/roles/99") + + @helpers.async_test + async def test_sync_cog_on_guild_role_update(self): + """A PUT request should be sent if the colour, name, permissions, or position changes.""" + self.assertTrue(self.cog.on_guild_role_update.__cog_listener__) + + role_data = { + "colour": 49, + "id": 777, + "name": "rolename", + "permissions": 8, + "position": 23, + } + subtests = ( + (True, ("colour", "name", "permissions", "position")), + (False, ("hoist", "mentionable")), + ) + + for should_put, attributes in subtests: + for attribute in attributes: + with self.subTest(should_put=should_put, changed_attribute=attribute): + self.bot.api_client.put.reset_mock() + + after_role_data = role_data.copy() + after_role_data[attribute] = 876 + + before_role = helpers.MockRole(**role_data) + after_role = helpers.MockRole(**after_role_data) + + await self.cog.on_guild_role_update(before_role, after_role) + + if should_put: + self.bot.api_client.put.assert_called_once_with( + f"bot/roles/{after_role.id}", + json=after_role_data + ) + else: + self.bot.api_client.put.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_member_remove(self): + """Member should patched to set in_guild as False.""" + self.assertTrue(self.cog.on_member_remove.__cog_listener__) + + member = helpers.MockMember() + await self.cog.on_member_remove(member) + + self.cog.patch_user.assert_called_once_with( + member.id, + updated_information={"in_guild": False} + ) + + @helpers.async_test + async def test_sync_cog_on_member_update_roles(self): + """Members should be patched if their roles have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + # Roles are intentionally unsorted. + before_roles = [helpers.MockRole(id=12), helpers.MockRole(id=30), helpers.MockRole(id=20)] + before_member = helpers.MockMember(roles=before_roles) + after_member = helpers.MockMember(roles=before_roles[1:]) + + await self.cog.on_member_update(before_member, after_member) + + data = {"roles": sorted(role.id for role in after_member.roles)} + self.cog.patch_user.assert_called_once_with(after_member.id, updated_information=data) + + @helpers.async_test + async def test_sync_cog_on_member_update_other(self): + """Members should not be patched if other attributes have changed.""" + self.assertTrue(self.cog.on_member_update.__cog_listener__) + + subtests = ( + ("activities", discord.Game("Pong"), discord.Game("Frogger")), + ("nick", "old nick", "new nick"), + ("status", discord.Status.online, discord.Status.offline), + ) + + for attribute, old_value, new_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + before_member = helpers.MockMember(**{attribute: old_value}) + after_member = helpers.MockMember(**{attribute: new_value}) + + await self.cog.on_member_update(before_member, after_member) + + self.cog.patch_user.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_user_update(self): + """A user should be patched only if the name, discriminator, or avatar changes.""" + self.assertTrue(self.cog.on_user_update.__cog_listener__) + + before_data = { + "name": "old name", + "discriminator": "1234", + "avatar": "old avatar", + "bot": False, + } + + subtests = ( + (True, "name", "name", "new name", "new name"), + (True, "discriminator", "discriminator", "8765", 8765), + (True, "avatar", "avatar_hash", "9j2e9", "9j2e9"), + (False, "bot", "bot", True, True), + ) + + for should_patch, attribute, api_field, value, api_value in subtests: + with self.subTest(attribute=attribute): + self.cog.patch_user.reset_mock() + + after_data = before_data.copy() + after_data[attribute] = value + before_user = helpers.MockUser(**before_data) + after_user = helpers.MockUser(**after_data) + + await self.cog.on_user_update(before_user, after_user) + + if should_patch: + self.cog.patch_user.assert_called_once() + + # Don't care if *all* keys are present; only the changed one is required + call_args = self.cog.patch_user.call_args + self.assertEqual(call_args[0][0], after_user.id) + self.assertIn("updated_information", call_args[1]) + + updated_information = call_args[1]["updated_information"] + self.assertIn(api_field, updated_information) + self.assertEqual(updated_information[api_field], api_value) + else: + self.cog.patch_user.assert_not_called() + + async def on_member_join_helper(self, side_effect: Exception) -> dict: + """ + Helper to set `side_effect` for on_member_join and assert a PUT request was sent. + + The request data for the mock member is returned. All exceptions will be re-raised. + """ + member = helpers.MockMember( + discriminator="1234", + roles=[helpers.MockRole(id=22), helpers.MockRole(id=12)], + ) + + data = { + "avatar_hash": member.avatar, + "discriminator": int(member.discriminator), + "id": member.id, + "in_guild": True, + "name": member.name, + "roles": sorted(role.id for role in member.roles) + } + + self.bot.api_client.put.reset_mock(side_effect=True) + self.bot.api_client.put.side_effect = side_effect + + try: + await self.cog.on_member_join(member) + except Exception: + raise + finally: + self.bot.api_client.put.assert_called_once_with( + f"bot/users/{member.id}", + json=data + ) + + return data + + @helpers.async_test + async def test_sync_cog_on_member_join(self): + """Should PUT user's data or POST it if the user doesn't exist.""" + for side_effect in (None, self.response_error(404)): + with self.subTest(side_effect=side_effect): + self.bot.api_client.post.reset_mock() + data = await self.on_member_join_helper(side_effect) + + if side_effect: + self.bot.api_client.post.assert_called_once_with("bot/users", json=data) + else: + self.bot.api_client.post.assert_not_called() + + @helpers.async_test + async def test_sync_cog_on_member_join_non_404(self): + """ResponseCodeError should be re-raised if status code isn't a 404.""" + with self.assertRaises(ResponseCodeError): + await self.on_member_join_helper(self.response_error(500)) + + self.bot.api_client.post.assert_not_called() + + +class SyncCogCommandTests(SyncCogTestCase, CommandTestCase): + """Tests for the commands in the Sync cog.""" + + @helpers.async_test + async def test_sync_roles_command(self): + """sync() should be called on the RoleSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_roles_command.callback(self.cog, ctx) + + self.cog.role_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + @helpers.async_test + async def test_sync_users_command(self): + """sync() should be called on the UserSyncer.""" + ctx = helpers.MockContext() + await self.cog.sync_users_command.callback(self.cog, ctx) + + self.cog.user_syncer.sync.assert_called_once_with(ctx.guild, ctx) + + def test_commands_require_admin(self): + """The sync commands should only run if the author has the administrator permission.""" + cmds = ( + self.cog.sync_group, + self.cog.sync_roles_command, + self.cog.sync_users_command, + ) + + for cmd in cmds: + with self.subTest(cmd=cmd): + self.assertHasPermissionsCheck(cmd, {"administrator": True}) diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py index 27ae27639..14fb2577a 100644 --- a/tests/bot/cogs/sync/test_roles.py +++ b/tests/bot/cogs/sync/test_roles.py @@ -1,126 +1,165 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -class GetRolesForSyncTests(unittest.TestCase): - """Tests constructing the roles to synchronize with the site.""" - - def test_get_roles_for_sync_empty_return_for_equal_roles(self): - """No roles should be synced when no diff is found.""" - api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), set(), set()) - ) - - def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): - """Roles to be synced are returned when non-ID attributes differ.""" - api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} - guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - (set(), guild_roles, set()) - ) - - def test_get_roles_only_returns_roles_that_require_update(self): - """Roles that require an update should be returned as the second tuple element.""" - api_roles = { - Role(id=41, name='old name', colour=33, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - guild_roles = { - Role(id=41, name='new name', colour=35, permissions=0x8, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_new_roles_in_first_tuple_element(self): - """Newly created roles are returned as the first tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=53, name='other role', colour=55, permissions=0, position=2) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, - set(), - set(), - ) - ) - - def test_get_roles_returns_roles_to_update_and_new_roles(self): - """Newly created and updated roles should be returned together.""" - api_roles = { - Role(id=41, name='old name', colour=35, permissions=0x8, position=1), - } - guild_roles = { - Role(id=41, name='new name', colour=40, permissions=0x16, position=2), - Role(id=53, name='other role', colour=55, permissions=0, position=3) - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, - {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, - set(), - ) - ) - - def test_get_roles_returns_roles_to_delete(self): - """Roles to be deleted should be returned as the third tuple element.""" - api_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - } - guild_roles = { - Role(id=41, name='name', colour=35, permissions=0x8, position=1), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - set(), - set(), - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) - - def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): - """When roles were added, updated, and removed, all of them are returned properly.""" - api_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), - Role(id=71, name='to update', colour=99, permissions=0x9, position=3), - } - guild_roles = { - Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), - Role(id=81, name='to create', colour=99, permissions=0x9, position=4), - Role(id=71, name='updated', colour=101, permissions=0x5, position=3), - } - - self.assertEqual( - get_roles_for_sync(guild_roles, api_roles), - ( - {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, - {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, - {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, - ) - ) +import discord + +from bot.cogs.sync.syncers import RoleSyncer, _Diff, _Role +from tests import helpers + + +def fake_role(**kwargs): + """Fixture to return a dictionary representing a role with default values set.""" + kwargs.setdefault("id", 9) + kwargs.setdefault("name", "fake role") + kwargs.setdefault("colour", 7) + kwargs.setdefault("permissions", 0) + kwargs.setdefault("position", 55) + + return kwargs + + +class RoleSyncerDiffTests(unittest.TestCase): + """Tests for determining differences between roles in the DB and roles in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @staticmethod + def get_guild(*roles): + """Fixture to return a guild object with the given roles.""" + guild = helpers.MockGuild() + guild.roles = [] + + for role in roles: + mock_role = helpers.MockRole(**role) + mock_role.colour = discord.Colour(role["colour"]) + mock_role.permissions = discord.Permissions(role["permissions"]) + guild.roles.append(mock_role) + + return guild + + @helpers.async_test + async def test_empty_diff_for_identical_roles(self): + """No differences should be found if the roles in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_updated_roles(self): + """Only updated roles should be added to the 'updated' set of the diff.""" + updated_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role(id=41, name="old"), fake_role()] + guild = self.get_guild(updated_role, fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_Role(**updated_role)}, set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_roles(self): + """Only new roles should be added to the 'created' set of the diff.""" + new_role = fake_role(id=41, name="new") + + self.bot.api_client.get.return_value = [fake_role()] + guild = self.get_guild(fake_role(), new_role) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new_role)}, set(), set()) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_deleted_roles(self): + """Only deleted roles should be added to the 'deleted' set of the diff.""" + deleted_role = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [fake_role(), deleted_role] + guild = self.get_guild(fake_role()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), {_Role(**deleted_role)}) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_updated_and_deleted_roles(self): + """When roles are added, updated, and removed, all of them are returned properly.""" + new = fake_role(id=41, name="new") + updated = fake_role(id=71, name="updated") + deleted = fake_role(id=61, name="deleted") + + self.bot.api_client.get.return_value = [ + fake_role(), + fake_role(id=71, name="updated name"), + deleted, + ] + guild = self.get_guild(fake_role(), new, updated) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_Role(**new)}, {_Role(**updated)}, {_Role(**deleted)}) + + self.assertEqual(actual_diff, expected_diff) + + +class RoleSyncerSyncTests(unittest.TestCase): + """Tests for the API requests that sync roles.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = RoleSyncer(self.bot) + + @helpers.async_test + async def test_sync_created_roles(self): + """Only POST requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(role_tuples, set(), set()) + await self.syncer._sync(diff) + + calls = [mock.call("bot/roles", json=role) for role in roles] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(roles)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_updated_roles(self): + """Only PUT requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), role_tuples, set()) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}", json=role) for role in roles] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_deleted_roles(self): + """Only DELETE requests should be made with the correct payload.""" + roles = [fake_role(id=111), fake_role(id=222)] + + role_tuples = {_Role(**role) for role in roles} + diff = _Diff(set(), set(), role_tuples) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/roles/{role['id']}") for role in roles] + self.bot.api_client.delete.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.delete.call_count, len(roles)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.put.assert_not_called() diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py index ccaf67490..421bf6bb6 100644 --- a/tests/bot/cogs/sync/test_users.py +++ b/tests/bot/cogs/sync/test_users.py @@ -1,84 +1,169 @@ import unittest +from unittest import mock -from bot.cogs.sync.syncers import User, get_users_for_sync +from bot.cogs.sync.syncers import UserSyncer, _Diff, _User +from tests import helpers def fake_user(**kwargs): - kwargs.setdefault('id', 43) - kwargs.setdefault('name', 'bob the test man') - kwargs.setdefault('discriminator', 1337) - kwargs.setdefault('avatar_hash', None) - kwargs.setdefault('roles', (666,)) - kwargs.setdefault('in_guild', True) - return User(**kwargs) - - -class GetUsersForSyncTests(unittest.TestCase): - """Tests constructing the users to synchronize with the site.""" - - def test_get_users_for_sync_returns_nothing_for_empty_params(self): - """When no users are given, none are returned.""" - self.assertEqual( - get_users_for_sync({}, {}), - (set(), set()) - ) - - def test_get_users_for_sync_returns_nothing_for_equal_users(self): - """When no users are updated, none are returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) - - def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): - """When a non-ID-field differs, the user to update is returned.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(name='new fancy name')} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(name='new fancy name')}) - ) - - def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): - """When new users join the guild, they are returned as the first tuple element.""" - api_users = {43: fake_user()} - guild_users = {43: fake_user(), 63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, set()) - ) - - def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): + """Fixture to return a dictionary representing a user with default values set.""" + kwargs.setdefault("id", 43) + kwargs.setdefault("name", "bob the test man") + kwargs.setdefault("discriminator", 1337) + kwargs.setdefault("avatar_hash", None) + kwargs.setdefault("roles", (666,)) + kwargs.setdefault("in_guild", True) + + return kwargs + + +class UserSyncerDiffTests(unittest.TestCase): + """Tests for determining differences between users in the DB and users in the Guild cache.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @staticmethod + def get_guild(*members): + """Fixture to return a guild object with the given members.""" + guild = helpers.MockGuild() + guild.members = [] + + for member in members: + member = member.copy() + member["avatar"] = member.pop("avatar_hash") + del member["in_guild"] + + mock_member = helpers.MockMember(**member) + mock_member.roles = [helpers.MockRole(id=role_id) for role_id in member["roles"]] + + guild.members.append(mock_member) + + return guild + + @helpers.async_test + async def test_empty_diff_for_no_users(self): + """When no users are given, an empty diff should be returned.""" + guild = self.get_guild() + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_empty_diff_for_identical_users(self): + """No differences should be found if the users in the guild and DB are identical.""" + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_updated_users(self): + """Only updated users should be added to the 'updated' set of the diff.""" + updated_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user(id=99, name="old"), fake_user()] + guild = self.get_guild(updated_user, fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**updated_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_users(self): + """Only new users should be added to the 'created' set of the diff.""" + new_user = fake_user(id=99, name="new") + + self.bot.api_client.get.return_value = [fake_user()] + guild = self.get_guild(fake_user(), new_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, set(), None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_sets_in_guild_false_for_leaving_users(self): """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" - api_users = {43: fake_user(), 63: fake_user(id=63)} - guild_users = {43: fake_user()} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), {fake_user(id=63, in_guild=False)}) - ) - - def test_get_users_for_sync_updates_and_creates_users_as_needed(self): - """When one user left and another one was updated, both are returned.""" - api_users = {43: fake_user()} - guild_users = {63: fake_user(id=63)} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - ({fake_user(id=63)}, {fake_user(in_guild=False)}) - ) - - def test_get_users_for_sync_does_not_duplicate_update_users(self): - """When the API knows a user the guild doesn't, nothing is performed.""" - api_users = {43: fake_user(in_guild=False)} - guild_users = {} - - self.assertEqual( - get_users_for_sync(guild_users, api_users), - (set(), set()) - ) + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), {_User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_diff_for_new_updated_and_leaving_users(self): + """When users are added, updated, and removed, all of them are returned properly.""" + new_user = fake_user(id=99, name="new") + updated_user = fake_user(id=55, name="updated") + leaving_user = fake_user(id=63, in_guild=False) + + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=55), fake_user(id=63)] + guild = self.get_guild(fake_user(), new_user, updated_user) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = ({_User(**new_user)}, {_User(**updated_user), _User(**leaving_user)}, None) + + self.assertEqual(actual_diff, expected_diff) + + @helpers.async_test + async def test_empty_diff_for_db_users_not_in_guild(self): + """When the DB knows a user the guild doesn't, no difference is found.""" + self.bot.api_client.get.return_value = [fake_user(), fake_user(id=63, in_guild=False)] + guild = self.get_guild(fake_user()) + + actual_diff = await self.syncer._get_diff(guild) + expected_diff = (set(), set(), None) + + self.assertEqual(actual_diff, expected_diff) + + +class UserSyncerSyncTests(unittest.TestCase): + """Tests for the API requests that sync users.""" + + def setUp(self): + self.bot = helpers.MockBot() + self.syncer = UserSyncer(self.bot) + + @helpers.async_test + async def test_sync_created_users(self): + """Only POST requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(user_tuples, set(), None) + await self.syncer._sync(diff) + + calls = [mock.call("bot/users", json=user) for user in users] + self.bot.api_client.post.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.post.call_count, len(users)) + + self.bot.api_client.put.assert_not_called() + self.bot.api_client.delete.assert_not_called() + + @helpers.async_test + async def test_sync_updated_users(self): + """Only PUT requests should be made with the correct payload.""" + users = [fake_user(id=111), fake_user(id=222)] + + user_tuples = {_User(**user) for user in users} + diff = _Diff(set(), user_tuples, None) + await self.syncer._sync(diff) + + calls = [mock.call(f"bot/users/{user['id']}", json=user) for user in users] + self.bot.api_client.put.assert_has_calls(calls, any_order=True) + self.assertEqual(self.bot.api_client.put.call_count, len(users)) + + self.bot.api_client.post.assert_not_called() + self.bot.api_client.delete.assert_not_called() diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py index d07b2bce1..5b0a3b8c3 100644 --- a/tests/bot/cogs/test_duck_pond.py +++ b/tests/bot/cogs/test_duck_pond.py @@ -54,7 +54,7 @@ class DuckPondTests(base.LoggingTestCase): asyncio.run(self.cog.fetch_webhook()) - self.bot.wait_until_ready.assert_called_once() + self.bot.wait_until_guild_available.assert_called_once() self.bot.fetch_webhook.assert_called_once_with(1) self.assertEqual(self.cog.webhook, "dummy webhook") @@ -67,7 +67,7 @@ class DuckPondTests(base.LoggingTestCase): with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: asyncio.run(self.cog.fetch_webhook()) - self.bot.wait_until_ready.assert_called_once() + self.bot.wait_until_guild_available.assert_called_once() self.bot.fetch_webhook.assert_called_once_with(1) self.assertEqual(len(log_watcher.records), 1) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 4496a2ae0..8443cfe71 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -19,7 +19,7 @@ class InformationCogTests(unittest.TestCase): @classmethod def setUpClass(cls): - cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderator) + cls.moderator_role = helpers.MockRole(name="Moderator", id=constants.Roles.moderators) def setUp(self): """Sets up fresh objects for each test.""" @@ -125,10 +125,10 @@ class InformationCogTests(unittest.TestCase): ) ], members=[ - *(helpers.MockMember(status='online') for _ in range(2)), - *(helpers.MockMember(status='idle') for _ in range(1)), - *(helpers.MockMember(status='dnd') for _ in range(4)), - *(helpers.MockMember(status='offline') for _ in range(3)), + *(helpers.MockMember(status=discord.Status.online) for _ in range(2)), + *(helpers.MockMember(status=discord.Status.idle) for _ in range(1)), + *(helpers.MockMember(status=discord.Status.dnd) for _ in range(4)), + *(helpers.MockMember(status=discord.Status.offline) for _ in range(3)), ], member_count=1_234, icon_url='a-lemon.jpg', @@ -153,9 +153,9 @@ class InformationCogTests(unittest.TestCase): **Counts** Members: {self.ctx.guild.member_count:,} Roles: {len(self.ctx.guild.roles)} - Text: 1 - Voice: 1 - Channel categories: 1 + Category channels: 1 + Text channels: 1 + Voice channels: 1 **Members** {constants.Emojis.status_online} 2 @@ -521,7 +521,7 @@ class UserCommandTests(unittest.TestCase): """A regular user should not be able to use this command outside of bot-commands.""" constants.MODERATION_ROLES = [self.moderator_role.id] constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=100)) @@ -533,7 +533,7 @@ class UserCommandTests(unittest.TestCase): def test_regular_user_may_use_command_in_bot_commands_channel(self, create_embed, constants): """A regular user should be allowed to use `!user` targeting themselves in bot-commands.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -546,7 +546,7 @@ class UserCommandTests(unittest.TestCase): def test_regular_user_can_explicitly_target_themselves(self, create_embed, constants): """A user should target itself with `!user` when a `user` argument was not provided.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.author, channel=helpers.MockTextChannel(id=50)) @@ -559,7 +559,7 @@ class UserCommandTests(unittest.TestCase): def test_staff_members_can_bypass_channel_restriction(self, create_embed, constants): """Staff members should be able to bypass the bot-commands channel restriction.""" constants.STAFF_ROLES = [self.moderator_role.id] - constants.Channels.bot = 50 + constants.Channels.bot_commands = 50 ctx = helpers.MockContext(author=self.moderator, channel=helpers.MockTextChannel(id=200)) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index e69de29bb..36c986fe1 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -0,0 +1,76 @@ +import unittest +from abc import ABCMeta, abstractmethod +from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple + +from tests.helpers import MockMessage + + +class DisallowedCase(NamedTuple): + """Encapsulation for test cases expected to fail.""" + recent_messages: List[MockMessage] + culprits: Iterable[str] + n_violations: int + + +class RuleTest(unittest.TestCase, metaclass=ABCMeta): + """ + Abstract class for antispam rule test cases. + + Tests for specific rules should inherit from `RuleTest` and implement + `relevant_messages` and `get_report`. Each instance should also set the + `apply` and `config` attributes as necessary. + + The execution of test cases can then be delegated to the `run_allowed` + and `run_disallowed` methods. + """ + + apply: Callable # The tested rule's apply function + config: Dict[str, int] + + async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: + """Run all `cases` against `self.apply` expecting them to pass.""" + for recent_messages in cases: + last_message = recent_messages[0] + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + config=self.config, + ): + self.assertIsNone( + await self.apply(last_message, recent_messages, self.config) + ) + + async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: + """Run all `cases` against `self.apply` expecting them to fail.""" + for case in cases: + recent_messages, culprits, n_violations = case + last_message = recent_messages[0] + relevant_messages = self.relevant_messages(case) + desired_output = ( + self.get_report(case), + culprits, + relevant_messages, + ) + + with self.subTest( + last_message=last_message, + recent_messages=recent_messages, + relevant_messages=relevant_messages, + n_violations=n_violations, + config=self.config, + ): + self.assertTupleEqual( + await self.apply(last_message, recent_messages, self.config), + desired_output, + ) + + @abstractmethod + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + """Give expected relevant messages for `case`.""" + raise NotImplementedError + + @abstractmethod + def get_report(self, case: DisallowedCase) -> str: + """Give expected error report for `case`.""" + raise NotImplementedError diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index d7187f315..e54b4b5b8 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,98 +1,71 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import attachments +from tests.bot.rules import DisallowedCase, RuleTest from tests.helpers import MockMessage, async_test -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_attachments: int - - -def msg(author: str, total_attachments: int) -> MockMessage: +def make_msg(author: str, total_attachments: int) -> MockMessage: """Builds a message with `total_attachments` attachments.""" return MockMessage(author=author, attachments=list(range(total_attachments))) -class AttachmentRuleTests(unittest.TestCase): +class AttachmentRuleTests(RuleTest): """Tests applying the `attachments` antispam rule.""" def setUp(self): - self.config = {"max": 5} + self.apply = attachments.apply + self.config = {"max": 5, "interval": 10} @async_test async def test_allows_messages_without_too_many_attachments(self): """Messages without too many attachments are allowed as-is.""" cases = ( - [msg("bob", 0), msg("bob", 0), msg("bob", 0)], - [msg("bob", 2), msg("bob", 2)], - [msg("bob", 2), msg("alice", 2), msg("bob", 2)], + [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], + [make_msg("bob", 2), make_msg("bob", 2)], + [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)], ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await attachments.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) @async_test async def test_disallows_messages_with_too_many_attachments(self): """Messages with too many attachments trigger the rule.""" cases = ( - Case( - [msg("bob", 4), msg("bob", 0), msg("bob", 6)], + DisallowedCase( + [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)], ("bob",), - 10 + 10, ), - Case( - [msg("bob", 4), msg("alice", 6), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)], ("bob",), - 6 + 6, ), - Case( - [msg("alice", 6)], + DisallowedCase( + [make_msg("alice", 6)], ("alice",), - 6 + 6, ), - ( - [msg("alice", 1) for _ in range(6)], + DisallowedCase( + [make_msg("alice", 1) for _ in range(6)], ("alice",), - 6 + 6, ), ) - for recent_messages, culprit, total_attachments in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if ( - msg.author == last_message.author - and len(msg.attachments) > 0 - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if ( + msg.author == last_message.author + and len(msg.attachments) > 0 ) + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - total_attachments=total_attachments, - config=self.config - ): - desired_output = ( - f"sent {total_attachments} attachments in {self.config['max']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await attachments.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py new file mode 100644 index 000000000..72f0be0c7 --- /dev/null +++ b/tests/bot/rules/test_burst.py @@ -0,0 +1,56 @@ +from typing import Iterable + +from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: + """ + Init a MockMessage instance with author set to `author`. + + This serves as a shorthand / alias to keep the test cases visually clean. + """ + return MockMessage(author=author) + + +class BurstRuleTests(RuleTest): + """Tests the `burst` antispam rule.""" + + def setUp(self): + self.apply = burst.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("bob"), make_msg("bob")], + [make_msg("bob"), make_msg("alice"), make_msg("bob")], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases where the amount of messages exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("bob")], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], + ("bob",), + 3, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py new file mode 100644 index 000000000..47367a5f8 --- /dev/null +++ b/tests/bot/rules/test_burst_shared.py @@ -0,0 +1,59 @@ +from typing import Iterable + +from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: + """ + Init a MockMessage instance with the passed arg. + + This serves as a shorthand / alias to keep the test cases visually clean. + """ + return MockMessage(author=author) + + +class BurstSharedRuleTests(RuleTest): + """Tests the `burst_shared` antispam rule.""" + + def setUp(self): + self.apply = burst_shared.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """ + Cases that do not violate the rule. + + There really isn't more to test here than a single case. + """ + cases = ( + [make_msg("spongebob"), make_msg("patrick")], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases where the amount of messages exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("bob")], + {"bob"}, + 3, + ), + DisallowedCase( + [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], + {"bob", "alice"}, + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return case.recent_messages + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py new file mode 100644 index 000000000..7cc36f49e --- /dev/null +++ b/tests/bot/rules/test_chars.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_chars: int) -> MockMessage: + """Build a message with arbitrary content of `n_chars` length.""" + return MockMessage(author=author, content="A" * n_chars) + + +class CharsRuleTests(RuleTest): + """Tests the `chars` antispam rule.""" + + def setUp(self): + self.apply = chars.apply + self.config = { + "max": 20, # Max allowed sum of chars per user + "interval": 10, + } + + @async_test + async def test_allows_messages_within_limit(self): + """Cases with a total amount of chars within limit.""" + cases = ( + [make_msg("bob", 0)], + [make_msg("bob", 20)], + [make_msg("bob", 15), make_msg("alice", 15)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases where the total amount of chars exceeds the limit, triggering the rule.""" + cases = ( + DisallowedCase( + [make_msg("bob", 21)], + ("bob",), + 21, + ), + DisallowedCase( + [make_msg("bob", 15), make_msg("bob", 15)], + ("bob",), + 30, + ), + DisallowedCase( + [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], + ("alice",), + 30, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py new file mode 100644 index 000000000..0239b0b00 --- /dev/null +++ b/tests/bot/rules/test_discord_emojis.py @@ -0,0 +1,54 @@ +from typing import Iterable + +from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + +discord_emoji = "<:abcd:1234>" # Discord emojis follow the format <:name:id> + + +def make_msg(author: str, n_emojis: int) -> MockMessage: + """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" + return MockMessage(author=author, content=discord_emoji * n_emojis) + + +class DiscordEmojisRuleTests(RuleTest): + """Tests for the `discord_emojis` antispam rule.""" + + def setUp(self): + self.apply = discord_emojis.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases with a total amount of discord emojis within limit.""" + cases = ( + [make_msg("bob", 2)], + [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases with more than the allowed amount of discord emojis.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py new file mode 100644 index 000000000..59e0fb6ef --- /dev/null +++ b/tests/bot/rules/test_duplicates.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import duplicates +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, content: str) -> MockMessage: + """Give a MockMessage instance with `author` and `content` attrs.""" + return MockMessage(author=author, content=content) + + +class DuplicatesRuleTests(RuleTest): + """Tests the `duplicates` antispam rule.""" + + def setUp(self): + self.apply = duplicates.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("alice", "A"), make_msg("alice", "A")], + [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")], # Non-duplicate + [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")], # Different author + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases with too many duplicate messages from the same author.""" + cases = ( + DisallowedCase( + [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], + ("alice",), + 3, + ), + DisallowedCase( + [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], + ("bob",), + 3, # 4 duplicate messages, but only 3 from bob + ), + DisallowedCase( + [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], + ("bob",), + 3, # 4 message from bob, but only 3 duplicates + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if ( + msg.author == last_message.author + and msg.content == last_message.content + ) + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 02a5d5501..3c3f90e5f 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,26 +1,21 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import links +from tests.bot.rules import DisallowedCase, RuleTest from tests.helpers import MockMessage, async_test -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_links: int - - -def msg(author: str, total_links: int) -> MockMessage: +def make_msg(author: str, total_links: int) -> MockMessage: """Makes a message with `total_links` links.""" content = " ".join(["https://pydis.com"] * total_links) return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest): """Tests applying the `links` rule.""" def setUp(self): + self.apply = links.apply self.config = { "max": 2, "interval": 10 @@ -30,68 +25,45 @@ class LinksTests(unittest.TestCase): async def test_links_within_limit(self): """Messages with an allowed amount of links.""" cases = ( - [msg("bob", 0)], - [msg("bob", 2)], - [msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 - [msg("bob", 1), msg("bob", 1)], - [msg("bob", 2), msg("alice", 2)] # Only messages from latest author count + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 3)], # Filter only applies if len(messages_with_links) > 1 + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 2), make_msg("alice", 2)] # Only messages from latest author count ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await links.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) @async_test async def test_links_exceeding_limit(self): """Messages with a a higher than allowed amount of links.""" cases = ( - Case( - [msg("bob", 1), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 1), make_msg("bob", 2)], ("bob",), 3 ), - Case( - [msg("alice", 1), msg("alice", 1), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)], ("alice",), 3 ), - Case( - [msg("alice", 2), msg("bob", 3), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)], ("alice",), 3 ) ) - for recent_messages, culprit, total_links in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if msg.author == last_message.author - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - culprit=culprit, - total_links=total_links, - config=self.config - ): - desired_output = ( - f"sent {total_links} links in {self.config['interval']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await links.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ad49ead32..ebcdabac6 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,95 +1,67 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest from tests.helpers import MockMessage, async_test -class Case(NamedTuple): - recent_messages: List[MockMessage] - culprit: Tuple[str] - total_mentions: int - - -def msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_mentions: int) -> MockMessage: """Makes a message with `total_mentions` mentions.""" return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest): """Tests applying the `mentions` antispam rule.""" def setUp(self): + self.apply = mentions.apply self.config = { "max": 2, - "interval": 10 + "interval": 10, } @async_test async def test_mentions_within_limit(self): """Messages with an allowed amount of mentions.""" cases = ( - [msg("bob", 0)], - [msg("bob", 2)], - [msg("bob", 1), msg("bob", 1)], - [msg("bob", 1), msg("alice", 2)] + [make_msg("bob", 0)], + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("bob", 1)], + [make_msg("bob", 1), make_msg("alice", 2)], ) - for recent_messages in cases: - last_message = recent_messages[0] - - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - config=self.config - ): - self.assertIsNone( - await mentions.apply(last_message, recent_messages, self.config) - ) + await self.run_allowed(cases) @async_test async def test_mentions_exceeding_limit(self): """Messages with a higher than allowed amount of mentions.""" cases = ( - Case( - [msg("bob", 3)], + DisallowedCase( + [make_msg("bob", 3)], ("bob",), - 3 + 3, ), - Case( - [msg("alice", 2), msg("alice", 0), msg("alice", 1)], + DisallowedCase( + [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)], ("alice",), - 3 + 3, ), - Case( - [msg("bob", 2), msg("alice", 3), msg("bob", 2)], + DisallowedCase( + [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)], ("bob",), - 4 + 4, ) ) - for recent_messages, culprit, total_mentions in cases: - last_message = recent_messages[0] - relevant_messages = tuple( - msg - for msg in recent_messages - if msg.author == last_message.author - ) + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) - with self.subTest( - last_message=last_message, - recent_messages=recent_messages, - relevant_messages=relevant_messages, - culprit=culprit, - total_mentions=total_mentions, - cofig=self.config - ): - desired_output = ( - f"sent {total_mentions} mentions in {self.config['interval']}s", - culprit, - relevant_messages - ) - self.assertTupleEqual( - await mentions.apply(last_message, recent_messages, self.config), - desired_output - ) + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py new file mode 100644 index 000000000..d61c4609d --- /dev/null +++ b/tests/bot/rules/test_newlines.py @@ -0,0 +1,105 @@ +from typing import Iterable, List + +from bot.rules import newlines +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, newline_groups: List[int]) -> MockMessage: + """Init a MockMessage instance with `author` and content configured by `newline_groups". + + Configure content by passing a list of ints, where each int `n` will generate + a separate group of `n` newlines. + + Example: + newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" + """ + content = " ".join("\n" * n for n in newline_groups) + return MockMessage(author=author, content=content) + + +class TotalNewlinesRuleTests(RuleTest): + """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" + + def setUp(self): + self.apply = newlines.apply + self.config = { + "max": 5, # Max sum of newlines in relevant messages + "max_consecutive": 3, # Max newlines in one group, in one message + "interval": 10, + } + + @async_test + async def test_allows_messages_within_limit(self): + """Cases which do not violate the rule.""" + cases = ( + [make_msg("alice", [])], # Single message with no newlines + [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])], # 5 newlines in 2 messages + [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])], # 5 newlines from each author + [make_msg("bob", [1]), make_msg("alice", [5])], # Alice breaks the rule, but only bob is relevant + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_total(self): + """Cases which violate the rule by having too many newlines in total.""" + cases = ( + DisallowedCase( # Alice sends a total of 6 newlines (disallowed) + [make_msg("alice", [2, 2]), make_msg("alice", [2])], + ("alice",), + 6, + ), + DisallowedCase( # Here we test that only alice's newlines count in the sum + [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], + ("alice",), + 7, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_author = case.recent_messages[0].author + return tuple(msg for msg in case.recent_messages if msg.author == last_author) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} newlines in {self.config['interval']}s" + + +class GroupNewlinesRuleTests(RuleTest): + """ + Tests the `newlines` antispam rule against max consecutive newline violations. + + As these violations yield a different error report, they require a different + `get_report` implementation. + """ + + def setUp(self): + self.apply = newlines.apply + self.config = {"max": 5, "max_consecutive": 3, "interval": 10} + + @async_test + async def test_disallows_messages_consecutive(self): + """Cases which violate the rule due to having too many consecutive newlines.""" + cases = ( + DisallowedCase( # Bob sends a group of newlines too large + [make_msg("bob", [4])], + ("bob",), + 4, + ), + DisallowedCase( # Alice sends 5 in total (allowed), but 4 in one group (disallowed) + [make_msg("alice", [1]), make_msg("alice", [4])], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_author = case.recent_messages[0].author + return tuple(msg for msg in case.recent_messages if msg.author == last_author) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py new file mode 100644 index 000000000..b339cccf7 --- /dev/null +++ b/tests/bot/rules/test_role_mentions.py @@ -0,0 +1,57 @@ +from typing import Iterable + +from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_mentions: int) -> MockMessage: + """Build a MockMessage instance with `n_mentions` role mentions.""" + return MockMessage(author=author, role_mentions=[None] * n_mentions) + + +class RoleMentionsRuleTests(RuleTest): + """Tests for the `role_mentions` antispam rule.""" + + def setUp(self): + self.apply = role_mentions.apply + self.config = {"max": 2, "interval": 10} + + @async_test + async def test_allows_messages_within_limit(self): + """Cases with a total amount of role mentions within limit.""" + cases = ( + [make_msg("bob", 2)], + [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], + ) + + await self.run_allowed(cases) + + @async_test + async def test_disallows_messages_beyond_limit(self): + """Cases with more than the allowed amount of role mentions.""" + cases = ( + DisallowedCase( + [make_msg("bob", 3)], + ("bob",), + 3, + ), + DisallowedCase( + [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], + ("alice",), + 4, + ), + ) + + await self.run_disallowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 5a88adc5c..bdfcc73e4 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -1,9 +1,7 @@ -import logging import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from bot import api -from tests.base import LoggingTestCase from tests.helpers import async_test @@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase): self.assertEqual(error.response_text, "") self.assertIs(error.response, self.error_api_response) - def test_responde_code_error_string_representation_default_initialization(self): + def test_response_code_error_string_representation_default_initialization(self): """Test the string representation of `ResponseCodeError` initialized without text or json.""" error = api.ResponseCodeError(response=self.error_api_response) self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") @@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase): response_text=text_data ) self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") - - -class LoggingHandlerTests(LoggingTestCase): - """Tests the bot's API Log Handler.""" - - @classmethod - def setUpClass(cls): - cls.debug_log_record = logging.LogRecord( - name='my.logger', level=logging.DEBUG, - pathname='my/logger.py', lineno=666, - msg="Lemon wins", args=(), - exc_info=None - ) - - cls.trace_log_record = logging.LogRecord( - name='my.logger', level=logging.TRACE, - pathname='my/logger.py', lineno=666, - msg="This will not be logged", args=(), - exc_info=None - ) - - def setUp(self): - self.log_handler = api.APILoggingHandler(None) - - def test_emit_appends_to_queue_with_stopped_event_loop(self): - """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" - with patch("bot.api.APILoggingHandler.ship_off") as ship_off: - # Patch `ship_off` to ease testing against the return value of this coroutine. - ship_off.return_value = 42 - self.log_handler.emit(self.debug_log_record) - - self.assertListEqual(self.log_handler.queue, [42]) - - def test_emit_ignores_less_than_debug(self): - """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" - self.log_handler.emit(self.trace_log_record) - self.assertListEqual(self.log_handler.queue, []) - - def test_schedule_queued_tasks_for_empty_queue(self): - """`APILoggingHandler` should not schedule anything when the queue is empty.""" - with self.assertNotLogs(level=logging.DEBUG): - self.log_handler.schedule_queued_tasks() - - def test_schedule_queued_tasks_for_nonempty_queue(self): - """`APILoggingHandler` should schedule logs when the queue is not empty.""" - log = logging.getLogger("bot.api") - - with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: - self.log_handler.queue = [555] - self.log_handler.schedule_queued_tasks() - self.assertListEqual(self.log_handler.queue, []) - create_task.assert_called_once_with(555) - - [record] = logs.records - self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") - self.assertEqual(record.levelno, logging.DEBUG) - self.assertEqual(record.name, 'bot.api') - self.assertIn('via_handler', record.__dict__) diff --git a/tests/bot/test_utils.py b/tests/bot/test_utils.py index 58ae2a81a..d7bcc3ba6 100644 --- a/tests/bot/test_utils.py +++ b/tests/bot/test_utils.py @@ -35,18 +35,3 @@ class CaseInsensitiveDictTests(unittest.TestCase): instance = utils.CaseInsensitiveDict() instance.update({'FOO': 'bar'}) self.assertEqual(instance['foo'], 'bar') - - -class ChunkTests(unittest.TestCase): - """Tests the `chunk` method.""" - - def test_empty_chunking(self): - """Tests chunking on an empty iterable.""" - generator = utils.chunks(iterable=[], size=5) - self.assertEqual(list(generator), []) - - def test_list_chunking(self): - """Tests chunking a non-empty list.""" - iterable = [1, 2, 3, 4, 5] - generator = utils.chunks(iterable=iterable, size=2) - self.assertEqual(list(generator), [[1, 2], [3, 4], [5]]) diff --git a/tests/helpers.py b/tests/helpers.py index 6aee8623f..6f50f6ae3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,6 +12,7 @@ from typing import Any, Iterable, Optional import discord from discord.ext.commands import Context +from bot.api import APIClient from bot.bot import Bot @@ -281,9 +282,21 @@ class MockRole(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): information, see the `MockGuild` docstring. """ def __init__(self, **kwargs) -> None: - default_kwargs = {'id': next(self.discord_id), 'name': 'role', 'position': 1} + default_kwargs = { + 'id': next(self.discord_id), + 'name': 'role', + 'position': 1, + 'colour': discord.Colour(0xdeadbf), + 'permissions': discord.Permissions(), + } super().__init__(spec_set=role_instance, **collections.ChainMap(kwargs, default_kwargs)) + if isinstance(self.colour, int): + self.colour = discord.Colour(self.colour) + + if isinstance(self.permissions, int): + self.permissions = discord.Permissions(self.permissions) + if 'mention' not in kwargs: self.mention = f'&{self.name}' @@ -336,6 +349,18 @@ class MockUser(CustomMockMixin, unittest.mock.Mock, ColourMixin, HashableMixin): self.mention = f"@{self.name}" +class MockAPIClient(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock APIClient objects. + + Instances of this class will follow the specifications of `bot.api.APIClient` instances. + For more information, see the `MockGuild` docstring. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(spec_set=APIClient, **kwargs) + + # Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) bot_instance.http_session = None @@ -352,6 +377,7 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): def __init__(self, **kwargs) -> None: super().__init__(spec_set=bot_instance, **kwargs) + self.api_client = MockAPIClient() # self.wait_for is *not* a coroutine function, but returns a coroutine nonetheless and # and should therefore be awaited. (The documentation calls it a coroutine as well, which @@ -515,6 +541,7 @@ class MockReaction(CustomMockMixin, unittest.mock.MagicMock): self.emoji = kwargs.get('emoji', MockEmoji()) self.message = kwargs.get('message', MockMessage()) self.users = AsyncIteratorMock(kwargs.get('users', [])) + self.__str__.return_value = str(self.emoji) webhook_instance = discord.Webhook(data=unittest.mock.MagicMock(), adapter=unittest.mock.MagicMock()) @@ -3,7 +3,7 @@ max-line-length=120 docstring-convention=all import-order-style=pycharm application_import_names=bot,tests -exclude=.cache,.venv,constants.py +exclude=.cache,.venv,.git,constants.py ignore= B311,W503,E226,S311,T000 # Missing Docstrings @@ -15,5 +15,5 @@ ignore= # Docstring Content D400,D401,D402,D404,D405,D406,D407,D408,D409,D410,D411,D412,D413,D414,D416,D417 # Type Annotations - TYP002,TYP003,TYP101,TYP102,TYP204,TYP206 -per-file-ignores=tests/*:D,TYP + ANN002,ANN003,ANN101,ANN102,ANN204,ANN206 +per-file-ignores=tests/*:D,ANN |