diff options
| -rw-r--r-- | Pipfile | 4 | ||||
| -rw-r--r-- | Pipfile.lock | 130 | ||||
| -rw-r--r-- | azure-pipelines.yml | 2 | ||||
| -rw-r--r-- | bot/__init__.py | 87 | ||||
| -rw-r--r-- | bot/__main__.py | 13 | ||||
| -rw-r--r-- | bot/api.py | 74 | ||||
| -rw-r--r-- | bot/bot.py | 2 | ||||
| -rw-r--r-- | bot/cogs/defcon.py | 4 | ||||
| -rw-r--r-- | bot/cogs/error_handler.py | 26 | ||||
| -rw-r--r-- | bot/constants.py | 2 | ||||
| -rw-r--r-- | bot/rules/attachments.py | 2 | ||||
| -rw-r--r-- | config-default.yml | 1 | ||||
| -rw-r--r-- | tests/bot/rules/__init__.py | 76 | ||||
| -rw-r--r-- | tests/bot/rules/test_attachments.py | 97 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst.py | 56 | ||||
| -rw-r--r-- | tests/bot/rules/test_burst_shared.py | 59 | ||||
| -rw-r--r-- | tests/bot/rules/test_chars.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_discord_emojis.py | 54 | ||||
| -rw-r--r-- | tests/bot/rules/test_duplicates.py | 66 | ||||
| -rw-r--r-- | tests/bot/rules/test_links.py | 84 | ||||
| -rw-r--r-- | tests/bot/rules/test_mentions.py | 90 | ||||
| -rw-r--r-- | tests/bot/rules/test_newlines.py | 105 | ||||
| -rw-r--r-- | tests/bot/rules/test_role_mentions.py | 57 | ||||
| -rw-r--r-- | tests/bot/test_api.py | 64 | ||||
| -rw-r--r-- | tox.ini | 6 | 
25 files changed, 783 insertions, 444 deletions
@@ -6,7 +6,6 @@ name = "pypi"  [packages]  discord-py = "~=1.3.1"  aiodns = "~=2.0" -logmatic-python = "~=0.1"  aiohttp = "~=3.5"  sphinx = "~=2.2"  markdownify = "~=0.4" @@ -19,11 +18,12 @@ deepdiff = "~=4.0"  requests = "~=2.22"  more_itertools = "~=7.2"  urllib3 = ">=1.24.2,<1.25" +sentry-sdk = "~=0.14"  [dev-packages]  coverage = "~=4.5"  flake8 = "~=3.7" -flake8-annotations = "~=1.1" +flake8-annotations = "~=2.0"  flake8-bugbear = "~=19.8"  flake8-docstrings = "~=1.4"  flake8-import-order = "~=0.18" diff --git a/Pipfile.lock b/Pipfile.lock index 9e09260f1..6a8a982a4 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@  {      "_meta": {          "hash": { -            "sha256": "0a0354a8cbd25b19c61b68f928493a445e737dc6447c97f4c4b52fbf72d887ac" +            "sha256": "c7706a61eb96c06d073898018ea2dbcf5bd3b15d007496e2d60120a65647f31e"          },          "pipfile-spec": 6,          "requires": { @@ -191,13 +191,6 @@              ],              "version": "==2.11.1"          }, -        "logmatic-python": { -            "hashes": [ -                "sha256:0c15ac9f5faa6a60059b28910db642c3dc7722948c3cc940923f8c9039604342" -            ], -            "index": "pypi", -            "version": "==0.1.7" -        },          "lxml": {              "hashes": [                  "sha256:06d4e0bbb1d62e38ae6118406d7cdb4693a3fa34ee3762238bcb96c9e36a93cd", @@ -286,25 +279,25 @@          },          "multidict": {              "hashes": [ -                "sha256:13f3ebdb5693944f52faa7b2065b751cb7e578b8dd0a5bb8e4ab05ad0188b85e", -                "sha256:26502cefa86d79b86752e96639352c7247846515c864d7c2eb85d036752b643c", -                "sha256:4fba5204d32d5c52439f88437d33ad14b5f228e25072a192453f658bddfe45a7", -                "sha256:527124ef435f39a37b279653ad0238ff606b58328ca7989a6df372fd75d7fe26", -                "sha256:5414f388ffd78c57e77bd253cf829373721f450613de53dc85a08e34d806e8eb", -                "sha256:5eee66f882ab35674944dfa0d28b57fa51e160b4dce0ce19e47f495fdae70703", -                "sha256:63810343ea07f5cd86ba66ab66706243a6f5af075eea50c01e39b4ad6bc3c57a", -                "sha256:6bd10adf9f0d6a98ccc792ab6f83d18674775986ba9bacd376b643fe35633357", -                "sha256:83c6ddf0add57c6b8a7de0bc7e2d656be3eefeff7c922af9a9aae7e49f225625", -                "sha256:93166e0f5379cf6cd29746989f8a594fa7204dcae2e9335ddba39c870a287e1c", -                "sha256:9a7b115ee0b9b92d10ebc246811d8f55d0c57e82dbb6a26b23c9a9a6ad40ce0c", -                "sha256:a38baa3046cce174a07a59952c9f876ae8875ef3559709639c17fdf21f7b30dd", -                "sha256:a6d219f49821f4b2c85c6d426346a5d84dab6daa6f85ca3da6c00ed05b54022d", -                "sha256:a8ed33e8f9b67e3b592c56567135bb42e7e0e97417a4b6a771e60898dfd5182b", -                "sha256:d7d428488c67b09b26928950a395e41cc72bb9c3d5abfe9f0521940ee4f796d4", -                "sha256:dcfed56aa085b89d644af17442cdc2debaa73388feba4b8026446d168ca8dad7", -                "sha256:f29b885e4903bd57a7789f09fe9d60b6475a6c1a4c0eca874d8558f00f9d4b51" -            ], -            "version": "==4.7.4" +                "sha256:317f96bc0950d249e96d8d29ab556d01dd38888fbe68324f46fd834b430169f1", +                "sha256:42f56542166040b4474c0c608ed051732033cd821126493cf25b6c276df7dd35", +                "sha256:4b7df040fb5fe826d689204f9b544af469593fb3ff3a069a6ad3409f742f5928", +                "sha256:544fae9261232a97102e27a926019100a9db75bec7b37feedd74b3aa82f29969", +                "sha256:620b37c3fea181dab09267cd5a84b0f23fa043beb8bc50d8474dd9694de1fa6e", +                "sha256:6e6fef114741c4d7ca46da8449038ec8b1e880bbe68674c01ceeb1ac8a648e78", +                "sha256:7774e9f6c9af3f12f296131453f7b81dabb7ebdb948483362f5afcaac8a826f1", +                "sha256:85cb26c38c96f76b7ff38b86c9d560dea10cf3459bb5f4caf72fc1bb932c7136", +                "sha256:a326f4240123a2ac66bb163eeba99578e9d63a8654a59f4688a79198f9aa10f8", +                "sha256:ae402f43604e3b2bc41e8ea8b8526c7fa7139ed76b0d64fc48e28125925275b2", +                "sha256:aee283c49601fa4c13adc64c09c978838a7e812f85377ae130a24d7198c0331e", +                "sha256:b51249fdd2923739cd3efc95a3d6c363b67bbf779208e9f37fd5e68540d1a4d4", +                "sha256:bb519becc46275c594410c6c28a8a0adc66fe24fef154a9addea54c1adb006f5", +                "sha256:c2c37185fb0af79d5c117b8d2764f4321eeb12ba8c141a95d0aa8c2c1d0a11dd", +                "sha256:dc561313279f9d05a3d0ffa89cd15ae477528ea37aa9795c4654588a3287a9ab", +                "sha256:e439c9a10a95cb32abd708bb8be83b2134fa93790a4fb0535ca36db3dda94d20", +                "sha256:fc3b4adc2ee8474cb3cd2a155305d5f8eda0a9c91320f83e55748e1fcb68f8e3" +            ], +            "version": "==4.7.5"          },          "ordered-set": {              "hashes": [ @@ -388,12 +381,6 @@              "index": "pypi",              "version": "==2.8.1"          }, -        "python-json-logger": { -            "hashes": [ -                "sha256:b7a31162f2a01965a5efb94453ce69230ed208468b0bbc7fdfc56e6d8df2e281" -            ], -            "version": "==0.1.11" -        },          "pytz": {              "hashes": [                  "sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d", @@ -426,6 +413,14 @@              "index": "pypi",              "version": "==2.23.0"          }, +        "sentry-sdk": { +            "hashes": [ +                "sha256:b06dd27391fd11fb32f84fe054e6a64736c469514a718a99fb5ce1dff95d6b28", +                "sha256:e023da07cfbead3868e1e2ba994160517885a32dfd994fc455b118e37989479b" +            ], +            "index": "pypi", +            "version": "==0.14.1" +        },          "six": {              "hashes": [                  "sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a", @@ -449,11 +444,11 @@          },          "sphinx": {              "hashes": [ -                "sha256:525527074f2e0c2585f68f73c99b4dc257c34bbe308b27f5f8c7a6e20642742f", -                "sha256:543d39db5f82d83a5c1aa0c10c88f2b6cff2da3e711aa849b2c627b4b403bbd9" +                "sha256:776ff8333181138fae52df65be733127539623bb46cc692e7fa0fcfc80d7aa88", +                "sha256:ca762da97c3b5107cbf0ab9e11d3ec7ab8d3c31377266fd613b962ed971df709"              ],              "index": "pypi", -            "version": "==2.4.2" +            "version": "==2.4.3"          },          "sphinxcontrib-applehelp": {              "hashes": [ @@ -471,10 +466,10 @@          },          "sphinxcontrib-htmlhelp": {              "hashes": [ -                "sha256:4670f99f8951bd78cd4ad2ab962f798f5618b17675c35c5ac3b2132a14ea8422", -                "sha256:d4fd39a65a625c9df86d7fa8a2d9f3cd8299a3a4b15db63b50aac9e161d8eff7" +                "sha256:3c0bc24a2c41e340ac37c85ced6dafc879ab485c095b1d65d2461ac2f7cca86f", +                "sha256:e8f5bb7e31b2dbb25b9cc435c8ab7a79787ebf7f906155729338f3156d93659b"              ], -            "version": "==1.0.2" +            "version": "==1.0.3"          },          "sphinxcontrib-jsmath": {              "hashes": [ @@ -645,7 +640,8 @@          },          "distlib": {              "hashes": [ -                "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21" +                "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21", +                "sha256:3db50260f17a3479465fe376211c0816e6b0d1503a6c71caebe80360cab04828"              ],              "version": "==0.3.0"          }, @@ -688,11 +684,11 @@          },          "flake8-annotations": {              "hashes": [ -                "sha256:47705be09c6e56e9e3ac1656e8f5ed70862a4657116dc472f5a56c1bdc5172b1", -                "sha256:564702ace354e1059252755be79d082a70ae1851c86044ae1a96d0f5453280e9" +                "sha256:19a6637a5da1bb7ea7948483ca9e2b9e15b213e687e7bf5ff8c1bfc91c185006", +                "sha256:bb033b72cdd3a2b0a530bbdf2081f12fbea7d70baeaaebb5899723a45f424b8e"              ],              "index": "pypi", -            "version": "==1.2.0" +            "version": "==2.0.0"          },          "flake8-bugbear": {              "hashes": [ @@ -755,6 +751,14 @@              ],              "version": "==2.9"          }, +        "importlib-metadata": { +            "hashes": [ +                "sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302", +                "sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b" +            ], +            "markers": "python_version < '3.8'", +            "version": "==1.5.0" +        },          "mccabe": {              "hashes": [                  "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", @@ -865,6 +869,33 @@              ],              "version": "==0.10.0"          }, +        "typed-ast": { +            "hashes": [ +                "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355", +                "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919", +                "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa", +                "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652", +                "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75", +                "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01", +                "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d", +                "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1", +                "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907", +                "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c", +                "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3", +                "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b", +                "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614", +                "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb", +                "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b", +                "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41", +                "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6", +                "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34", +                "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe", +                "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4", +                "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7" +            ], +            "markers": "python_version < '3.8'", +            "version": "==1.4.1" +        },          "unittest-xml-reporting": {              "hashes": [                  "sha256:358bbdaf24a26d904cc1c26ef3078bca7fc81541e0a54c8961693cc96a6f35e0", @@ -883,10 +914,17 @@          },          "virtualenv": {              "hashes": [ -                "sha256:08f3623597ce73b85d6854fb26608a6f39ee9d055c81178dc6583803797f8994", -                "sha256:de2cbdd5926c48d7b84e0300dea9e8f276f61d186e8e49223d71d91250fbaebd" +                "sha256:531b142e300d405bb9faedad4adbeb82b4098b918e35209af2adef3129274aae", +                "sha256:5dd42a9f56307542bddc446cfd10ef6576f11910366a07609fe8d0d88fa8fb7e"              ], -            "version": "==20.0.4" +            "version": "==20.0.5" +        }, +        "zipp": { +            "hashes": [ +                "sha256:12248a63bbdf7548f89cb4c7cda4681e537031eda29c02ea29674bc6854460c2", +                "sha256:7c0f8e91abc0dc07a5068f315c52cb30c66bfbc581e5b50704c8a2f6ebae794a" +            ], +            "version": "==3.0.0"          }      }  } diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0400ac4d2..874364a6f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,7 +30,7 @@ jobs:        - script: python -m flake8          displayName: 'Run linter' -      - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner +      - script: BOT_API_KEY=foo BOT_SENTRY_DSN=blah BOT_TOKEN=bar WOLFRAM_API_KEY=baz REDDIT_CLIENT_ID=spam REDDIT_SECRET=ham coverage run -m xmlrunner          displayName: Run tests        - script: coverage report -m && coverage xml -o coverage.xml diff --git a/bot/__init__.py b/bot/__init__.py index 789ace5c0..90ab3c348 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -4,11 +4,8 @@ import sys  from logging import Logger, StreamHandler, handlers  from pathlib import Path -from logmatic import JsonFormatter - - -logging.TRACE = 5 -logging.addLevelName(logging.TRACE, "TRACE") +TRACE_LEVEL = logging.TRACE = 5 +logging.addLevelName(TRACE_LEVEL, "TRACE")  def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None: @@ -20,75 +17,29 @@ def monkeypatch_trace(self: logging.Logger, msg: str, *args, **kwargs) -> None:      logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)      """ -    if self.isEnabledFor(logging.TRACE): -        self._log(logging.TRACE, msg, args, **kwargs) +    if self.isEnabledFor(TRACE_LEVEL): +        self._log(TRACE_LEVEL, msg, args, **kwargs)  Logger.trace = monkeypatch_trace -# Set up logging -logging_handlers = [] - -# We can't import this yet, so we have to define it ourselves -DEBUG_MODE = True if 'local' in os.environ.get("SITE_URL", "local") else False - -LOG_DIR = Path("logs") -LOG_DIR.mkdir(exist_ok=True) - -if DEBUG_MODE: -    logging_handlers.append(StreamHandler(stream=sys.stdout)) - -    json_handler = logging.FileHandler(filename=Path(LOG_DIR, "log.json"), mode="w") -    json_handler.formatter = JsonFormatter() -    logging_handlers.append(json_handler) -else: - -    logfile = Path(LOG_DIR, "bot.log") -    megabyte = 1048576 - -    filehandler = handlers.RotatingFileHandler(logfile, maxBytes=(megabyte*5), backupCount=7) -    logging_handlers.append(filehandler) - -    json_handler = logging.StreamHandler(stream=sys.stdout) -    json_handler.formatter = JsonFormatter() -    logging_handlers.append(json_handler) - - -logging.basicConfig( -    format="%(asctime)s Bot: | %(name)33s | %(levelname)8s | %(message)s", -    datefmt="%b %d %H:%M:%S", -    level=logging.TRACE if DEBUG_MODE else logging.INFO, -    handlers=logging_handlers -) - -log = logging.getLogger(__name__) - - -for key, value in logging.Logger.manager.loggerDict.items(): -    # Force all existing loggers to the correct level and handlers -    # This happens long before we instantiate our loggers, so -    # those should still have the expected level - -    if key == "bot": -        continue - -    if not isinstance(value, logging.Logger): -        # There might be some logging.PlaceHolder objects in there -        continue +DEBUG_MODE = 'local' in os.environ.get("SITE_URL", "local") -    if DEBUG_MODE: -        value.setLevel(logging.DEBUG) -    else: -        value.setLevel(logging.INFO) +log_format = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s") -    for handler in value.handlers.copy(): -        value.removeHandler(handler) +stream_handler = StreamHandler(stream=sys.stdout) +stream_handler.setFormatter(log_format) -    for handler in logging_handlers: -        value.addHandler(handler) +log_file = Path("logs", "bot.log") +log_file.parent.mkdir(exist_ok=True) +file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7) +file_handler.setFormatter(log_format) +root_log = logging.getLogger() +root_log.setLevel(TRACE_LEVEL if DEBUG_MODE else logging.INFO) +root_log.addHandler(stream_handler) +root_log.addHandler(file_handler) -# Silence irrelevant loggers -logging.getLogger("aio_pika").setLevel(logging.ERROR) -logging.getLogger("discord").setLevel(logging.ERROR) -logging.getLogger("websockets").setLevel(logging.ERROR) +logging.getLogger("discord").setLevel(logging.WARNING) +logging.getLogger("websockets").setLevel(logging.WARNING) +logging.getLogger(__name__).setLevel(TRACE_LEVEL) diff --git a/bot/__main__.py b/bot/__main__.py index 84bc7094b..490163739 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,10 +1,23 @@ +import logging +  import discord +import sentry_sdk  from discord.ext.commands import when_mentioned_or +from sentry_sdk.integrations.logging import LoggingIntegration  from bot import patches  from bot.bot import Bot  from bot.constants import Bot as BotConfig, DEBUG_MODE +sentry_logging = LoggingIntegration( +    level=logging.TRACE, +    event_level=logging.WARNING +) + +sentry_sdk.init( +    dsn=BotConfig.sentry_dsn, +    integrations=[sentry_logging] +)  bot = Bot(      command_prefix=when_mentioned_or(BotConfig.prefix), diff --git a/bot/api.py b/bot/api.py index 56db99828..fb126b384 100644 --- a/bot/api.py +++ b/bot/api.py @@ -141,77 +141,3 @@ def loop_is_running() -> bool:      except RuntimeError:          return False      return True - - -class APILoggingHandler(logging.StreamHandler): -    """Site API logging handler.""" - -    def __init__(self, client: APIClient): -        logging.StreamHandler.__init__(self) -        self.client = client - -        # internal batch of shipoff tasks that must not be scheduled -        # on the event loop yet - scheduled when the event loop is ready. -        self.queue = [] - -    async def ship_off(self, payload: dict) -> None: -        """Ship log payload to the logging API.""" -        try: -            await self.client.post('logs', json=payload) -        except ResponseCodeError as err: -            log.warning( -                "Cannot send logging record to the site, got code %d.", -                err.response.status, -                extra={'via_handler': True} -            ) -        except Exception as err: -            log.warning( -                "Cannot send logging record to the site: %r", -                err, -                extra={'via_handler': True} -            ) - -    def emit(self, record: logging.LogRecord) -> None: -        """ -        Determine if a log record should be shipped to the logging API. - -        If the asyncio event loop is not yet running, log records will instead be put in a queue -        which will be consumed once the event loop is running. - -        The following two conditions are set: -            1. Do not log anything below DEBUG (only applies to the monkeypatched `TRACE` level) -            2. Ignore log records originating from this logging handler itself to prevent infinite recursion -        """ -        if ( -                record.levelno >= logging.DEBUG -                and not record.__dict__.get('via_handler') -        ): -            payload = { -                'application': 'bot', -                'logger_name': record.name, -                'level': record.levelname.lower(), -                'module': record.module, -                'line': record.lineno, -                'message': self.format(record) -            } - -            task = self.ship_off(payload) -            if not loop_is_running(): -                self.queue.append(task) -            else: -                asyncio.create_task(task) -                self.schedule_queued_tasks() - -    def schedule_queued_tasks(self) -> None: -        """Consume the queue and schedule the logging of each queued record.""" -        for task in self.queue: -            asyncio.create_task(task) - -        if self.queue: -            log.debug( -                "Scheduled %d pending logging tasks.", -                len(self.queue), -                extra={'via_handler': True} -            ) - -        self.queue.clear() diff --git a/bot/bot.py b/bot/bot.py index 8f808272f..cecee7b68 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -27,8 +27,6 @@ class Bot(commands.Bot):          self.http_session: Optional[aiohttp.ClientSession] = None          self.api_client = api.APIClient(loop=self.loop, connector=self.connector) -        log.addHandler(api.APILoggingHandler(self.api_client)) -      def add_cog(self, cog: commands.Cog) -> None:          """Adds a "cog" to the bot and logs the operation."""          super().add_cog(cog) diff --git a/bot/cogs/defcon.py b/bot/cogs/defcon.py index 3e7350fcc..a0d8fedd5 100644 --- a/bot/cogs/defcon.py +++ b/bot/cogs/defcon.py @@ -76,12 +76,12 @@ class Defcon(Cog):              if data["enabled"]:                  self.enabled = True                  self.days = timedelta(days=data["days"]) -                log.warning(f"DEFCON enabled: {self.days.days} days") +                log.info(f"DEFCON enabled: {self.days.days} days")              else:                  self.enabled = False                  self.days = timedelta(days=0) -                log.warning(f"DEFCON disabled") +                log.info(f"DEFCON disabled")              await self.update_channel_topic() diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py index 4b08546cc..5af5974cb 100644 --- a/bot/cogs/error_handler.py +++ b/bot/cogs/error_handler.py @@ -1,4 +1,3 @@ -import contextlib  import difflib  import logging @@ -17,6 +16,7 @@ from discord.ext.commands import (      UserInputError,  )  from discord.ext.commands import Cog, Context +from sentry_sdk import push_scope  from bot.api import ResponseCodeError  from bot.bot import Bot @@ -179,10 +179,26 @@ class ErrorHandler(Cog):              f"Sorry, an unexpected error occurred. Please let us know!\n\n"              f"```{e.__class__.__name__}: {e}```"          ) -        log.error( -            f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}" -        ) -        raise e + +        with push_scope() as scope: +            scope.user = { +                "id": ctx.author.id, +                "username": str(ctx.author) +            } + +            scope.set_tag("command", ctx.command.qualified_name) +            scope.set_tag("message_id", ctx.message.id) +            scope.set_tag("channel_id", ctx.channel.id) + +            scope.set_extra("full_message", ctx.message.content) + +            if ctx.guild is not None: +                scope.set_extra( +                    "jump_to", +                    f"https://discordapp.com/channels/{ctx.guild.id}/{ctx.channel.id}/{ctx.message.id}" +                ) + +            log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=e)  def setup(bot: Bot) -> None: diff --git a/bot/constants.py b/bot/constants.py index fe8e57322..a4c65a1f8 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -193,7 +193,7 @@ class Bot(metaclass=YAMLGetter):      prefix: str      token: str - +    sentry_dsn: str  class Filter(metaclass=YAMLGetter):      section = "filter" diff --git a/bot/rules/attachments.py b/bot/rules/attachments.py index 00bb2a949..8903c385c 100644 --- a/bot/rules/attachments.py +++ b/bot/rules/attachments.py @@ -19,7 +19,7 @@ async def apply(      if total_recent_attachments > config['max']:          return ( -            f"sent {total_recent_attachments} attachments in {config['max']}s", +            f"sent {total_recent_attachments} attachments in {config['interval']}s",              (last_message.author,),              relevant_messages          ) diff --git a/config-default.yml b/config-default.yml index 3345e6f2a..2eaf8ee06 100644 --- a/config-default.yml +++ b/config-default.yml @@ -1,6 +1,7 @@  bot:      prefix:      "!"      token:       !ENV "BOT_TOKEN" +    sentry_dsn:  !ENV "BOT_SENTRY_DSN"      cooldowns:          # Per channel, per tag. diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py index e69de29bb..36c986fe1 100644 --- a/tests/bot/rules/__init__.py +++ b/tests/bot/rules/__init__.py @@ -0,0 +1,76 @@ +import unittest +from abc import ABCMeta, abstractmethod +from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple + +from tests.helpers import MockMessage + + +class DisallowedCase(NamedTuple): +    """Encapsulation for test cases expected to fail.""" +    recent_messages: List[MockMessage] +    culprits: Iterable[str] +    n_violations: int + + +class RuleTest(unittest.TestCase, metaclass=ABCMeta): +    """ +    Abstract class for antispam rule test cases. + +    Tests for specific rules should inherit from `RuleTest` and implement +    `relevant_messages` and `get_report`. Each instance should also set the +    `apply` and `config` attributes as necessary. + +    The execution of test cases can then be delegated to the `run_allowed` +    and `run_disallowed` methods. +    """ + +    apply: Callable  # The tested rule's apply function +    config: Dict[str, int] + +    async def run_allowed(self, cases: Tuple[List[MockMessage], ...]) -> None: +        """Run all `cases` against `self.apply` expecting them to pass.""" +        for recent_messages in cases: +            last_message = recent_messages[0] + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                config=self.config, +            ): +                self.assertIsNone( +                    await self.apply(last_message, recent_messages, self.config) +                ) + +    async def run_disallowed(self, cases: Tuple[DisallowedCase, ...]) -> None: +        """Run all `cases` against `self.apply` expecting them to fail.""" +        for case in cases: +            recent_messages, culprits, n_violations = case +            last_message = recent_messages[0] +            relevant_messages = self.relevant_messages(case) +            desired_output = ( +                self.get_report(case), +                culprits, +                relevant_messages, +            ) + +            with self.subTest( +                last_message=last_message, +                recent_messages=recent_messages, +                relevant_messages=relevant_messages, +                n_violations=n_violations, +                config=self.config, +            ): +                self.assertTupleEqual( +                    await self.apply(last_message, recent_messages, self.config), +                    desired_output, +                ) + +    @abstractmethod +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        """Give expected relevant messages for `case`.""" +        raise NotImplementedError + +    @abstractmethod +    def get_report(self, case: DisallowedCase) -> str: +        """Give expected error report for `case`.""" +        raise NotImplementedError diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py index d7187f315..e54b4b5b8 100644 --- a/tests/bot/rules/test_attachments.py +++ b/tests/bot/rules/test_attachments.py @@ -1,98 +1,71 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import attachments +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_attachments: int - - -def msg(author: str, total_attachments: int) -> MockMessage: +def make_msg(author: str, total_attachments: int) -> MockMessage:      """Builds a message with `total_attachments` attachments."""      return MockMessage(author=author, attachments=list(range(total_attachments))) -class AttachmentRuleTests(unittest.TestCase): +class AttachmentRuleTests(RuleTest):      """Tests applying the `attachments` antispam rule."""      def setUp(self): -        self.config = {"max": 5} +        self.apply = attachments.apply +        self.config = {"max": 5, "interval": 10}      @async_test      async def test_allows_messages_without_too_many_attachments(self):          """Messages without too many attachments are allowed as-is."""          cases = ( -            [msg("bob", 0), msg("bob", 0), msg("bob", 0)], -            [msg("bob", 2), msg("bob", 2)], -            [msg("bob", 2), msg("alice", 2), msg("bob", 2)], +            [make_msg("bob", 0), make_msg("bob", 0), make_msg("bob", 0)], +            [make_msg("bob", 2), make_msg("bob", 2)], +            [make_msg("bob", 2), make_msg("alice", 2), make_msg("bob", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await attachments.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_disallows_messages_with_too_many_attachments(self):          """Messages with too many attachments trigger the rule."""          cases = ( -            Case( -                [msg("bob", 4), msg("bob", 0), msg("bob", 6)], +            DisallowedCase( +                [make_msg("bob", 4), make_msg("bob", 0), make_msg("bob", 6)],                  ("bob",), -                10 +                10,              ), -            Case( -                [msg("bob", 4), msg("alice", 6), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 4), make_msg("alice", 6), make_msg("bob", 2)],                  ("bob",), -                6 +                6,              ), -            Case( -                [msg("alice", 6)], +            DisallowedCase( +                [make_msg("alice", 6)],                  ("alice",), -                6 +                6,              ), -            ( -                [msg("alice", 1) for _ in range(6)], +            DisallowedCase( +                [make_msg("alice", 1) for _ in range(6)],                  ("alice",), -                6 +                6,              ),          ) -        for recent_messages, culprit, total_attachments in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if ( -                    msg.author == last_message.author -                    and len(msg.attachments) > 0 -                ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and len(msg.attachments) > 0              ) +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                total_attachments=total_attachments, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_attachments} attachments in {self.config['max']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await attachments.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} attachments in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst.py b/tests/bot/rules/test_burst.py new file mode 100644 index 000000000..72f0be0c7 --- /dev/null +++ b/tests/bot/rules/test_burst.py @@ -0,0 +1,56 @@ +from typing import Iterable + +from bot.rules import burst +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: +    """ +    Init a MockMessage instance with author set to `author`. + +    This serves as a shorthand / alias to keep the test cases visually clean. +    """ +    return MockMessage(author=author) + + +class BurstRuleTests(RuleTest): +    """Tests the `burst` antispam rule.""" + +    def setUp(self): +        self.apply = burst.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("bob"), make_msg("bob")], +            [make_msg("bob"), make_msg("alice"), make_msg("bob")], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the amount of messages exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("bob")], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], +                ("bob",), +                3, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_burst_shared.py b/tests/bot/rules/test_burst_shared.py new file mode 100644 index 000000000..47367a5f8 --- /dev/null +++ b/tests/bot/rules/test_burst_shared.py @@ -0,0 +1,59 @@ +from typing import Iterable + +from bot.rules import burst_shared +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str) -> MockMessage: +    """ +    Init a MockMessage instance with the passed arg. + +    This serves as a shorthand / alias to keep the test cases visually clean. +    """ +    return MockMessage(author=author) + + +class BurstSharedRuleTests(RuleTest): +    """Tests the `burst_shared` antispam rule.""" + +    def setUp(self): +        self.apply = burst_shared.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """ +        Cases that do not violate the rule. + +        There really isn't more to test here than a single case. +        """ +        cases = ( +            [make_msg("spongebob"), make_msg("patrick")], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the amount of messages exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("bob")], +                {"bob"}, +                3, +            ), +            DisallowedCase( +                [make_msg("bob"), make_msg("bob"), make_msg("alice"), make_msg("bob")], +                {"bob", "alice"}, +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return case.recent_messages + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_chars.py b/tests/bot/rules/test_chars.py new file mode 100644 index 000000000..7cc36f49e --- /dev/null +++ b/tests/bot/rules/test_chars.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import chars +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_chars: int) -> MockMessage: +    """Build a message with arbitrary content of `n_chars` length.""" +    return MockMessage(author=author, content="A" * n_chars) + + +class CharsRuleTests(RuleTest): +    """Tests the `chars` antispam rule.""" + +    def setUp(self): +        self.apply = chars.apply +        self.config = { +            "max": 20,  # Max allowed sum of chars per user +            "interval": 10, +        } + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of chars within limit.""" +        cases = ( +            [make_msg("bob", 0)], +            [make_msg("bob", 20)], +            [make_msg("bob", 15), make_msg("alice", 15)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases where the total amount of chars exceeds the limit, triggering the rule.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 21)], +                ("bob",), +                21, +            ), +            DisallowedCase( +                [make_msg("bob", 15), make_msg("bob", 15)], +                ("bob",), +                30, +            ), +            DisallowedCase( +                [make_msg("alice", 15), make_msg("bob", 20), make_msg("alice", 15)], +                ("alice",), +                30, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} characters in {self.config['interval']}s" diff --git a/tests/bot/rules/test_discord_emojis.py b/tests/bot/rules/test_discord_emojis.py new file mode 100644 index 000000000..0239b0b00 --- /dev/null +++ b/tests/bot/rules/test_discord_emojis.py @@ -0,0 +1,54 @@ +from typing import Iterable + +from bot.rules import discord_emojis +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + +discord_emoji = "<:abcd:1234>"  # Discord emojis follow the format <:name:id> + + +def make_msg(author: str, n_emojis: int) -> MockMessage: +    """Build a MockMessage instance with content containing `n_emojis` arbitrary emojis.""" +    return MockMessage(author=author, content=discord_emoji * n_emojis) + + +class DiscordEmojisRuleTests(RuleTest): +    """Tests for the `discord_emojis` antispam rule.""" + +    def setUp(self): +        self.apply = discord_emojis.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of discord emojis within limit.""" +        cases = ( +            [make_msg("bob", 2)], +            [make_msg("alice", 1), make_msg("bob", 2), make_msg("alice", 1)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with more than the allowed amount of discord emojis.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 3)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        return tuple(msg for msg in case.recent_messages if msg.author in case.culprits) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} emojis in {self.config['interval']}s" diff --git a/tests/bot/rules/test_duplicates.py b/tests/bot/rules/test_duplicates.py new file mode 100644 index 000000000..59e0fb6ef --- /dev/null +++ b/tests/bot/rules/test_duplicates.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from bot.rules import duplicates +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, content: str) -> MockMessage: +    """Give a MockMessage instance with `author` and `content` attrs.""" +    return MockMessage(author=author, content=content) + + +class DuplicatesRuleTests(RuleTest): +    """Tests the `duplicates` antispam rule.""" + +    def setUp(self): +        self.apply = duplicates.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("alice", "A"), make_msg("alice", "A")], +            [make_msg("alice", "A"), make_msg("alice", "B"), make_msg("alice", "C")],  # Non-duplicate +            [make_msg("alice", "A"), make_msg("bob", "A"), make_msg("alice", "A")],  # Different author +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with too many duplicate messages from the same author.""" +        cases = ( +            DisallowedCase( +                [make_msg("alice", "A"), make_msg("alice", "A"), make_msg("alice", "A")], +                ("alice",), +                3, +            ), +            DisallowedCase( +                [make_msg("bob", "A"), make_msg("alice", "A"), make_msg("bob", "A"), make_msg("bob", "A")], +                ("bob",), +                3,  # 4 duplicate messages, but only 3 from bob +            ), +            DisallowedCase( +                [make_msg("bob", "A"), make_msg("bob", "B"), make_msg("bob", "A"), make_msg("bob", "A")], +                ("bob",), +                3,  # 4 message from bob, but only 3 duplicates +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if ( +                msg.author == last_message.author +                and msg.content == last_message.content +            ) +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} duplicated messages in {self.config['interval']}s" diff --git a/tests/bot/rules/test_links.py b/tests/bot/rules/test_links.py index 02a5d5501..3c3f90e5f 100644 --- a/tests/bot/rules/test_links.py +++ b/tests/bot/rules/test_links.py @@ -1,26 +1,21 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import links +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_links: int - - -def msg(author: str, total_links: int) -> MockMessage: +def make_msg(author: str, total_links: int) -> MockMessage:      """Makes a message with `total_links` links."""      content = " ".join(["https://pydis.com"] * total_links)      return MockMessage(author=author, content=content) -class LinksTests(unittest.TestCase): +class LinksTests(RuleTest):      """Tests applying the `links` rule."""      def setUp(self): +        self.apply = links.apply          self.config = {              "max": 2,              "interval": 10 @@ -30,68 +25,45 @@ class LinksTests(unittest.TestCase):      async def test_links_within_limit(self):          """Messages with an allowed amount of links."""          cases = ( -            [msg("bob", 0)], -            [msg("bob", 2)], -            [msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 -            [msg("bob", 1), msg("bob", 1)], -            [msg("bob", 2), msg("alice", 2)]  # Only messages from latest author count +            [make_msg("bob", 0)], +            [make_msg("bob", 2)], +            [make_msg("bob", 3)],  # Filter only applies if len(messages_with_links) > 1 +            [make_msg("bob", 1), make_msg("bob", 1)], +            [make_msg("bob", 2), make_msg("alice", 2)]  # Only messages from latest author count          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await links.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_links_exceeding_limit(self):          """Messages with a a higher than allowed amount of links."""          cases = ( -            Case( -                [msg("bob", 1), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 1), make_msg("bob", 2)],                  ("bob",),                  3              ), -            Case( -                [msg("alice", 1), msg("alice", 1), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 1), make_msg("alice", 1), make_msg("alice", 1)],                  ("alice",),                  3              ), -            Case( -                [msg("alice", 2), msg("bob", 3), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 3), make_msg("alice", 1)],                  ("alice",),                  3              )          ) -        for recent_messages, culprit, total_links in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_links=total_links, -                config=self.config -            ): -                desired_output = ( -                    f"sent {total_links} links in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await links.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} links in {self.config['interval']}s" diff --git a/tests/bot/rules/test_mentions.py b/tests/bot/rules/test_mentions.py index ad49ead32..ebcdabac6 100644 --- a/tests/bot/rules/test_mentions.py +++ b/tests/bot/rules/test_mentions.py @@ -1,95 +1,67 @@ -import unittest -from typing import List, NamedTuple, Tuple +from typing import Iterable  from bot.rules import mentions +from tests.bot.rules import DisallowedCase, RuleTest  from tests.helpers import MockMessage, async_test -class Case(NamedTuple): -    recent_messages: List[MockMessage] -    culprit: Tuple[str] -    total_mentions: int - - -def msg(author: str, total_mentions: int) -> MockMessage: +def make_msg(author: str, total_mentions: int) -> MockMessage:      """Makes a message with `total_mentions` mentions."""      return MockMessage(author=author, mentions=list(range(total_mentions))) -class TestMentions(unittest.TestCase): +class TestMentions(RuleTest):      """Tests applying the `mentions` antispam rule."""      def setUp(self): +        self.apply = mentions.apply          self.config = {              "max": 2, -            "interval": 10 +            "interval": 10,          }      @async_test      async def test_mentions_within_limit(self):          """Messages with an allowed amount of mentions."""          cases = ( -            [msg("bob", 0)], -            [msg("bob", 2)], -            [msg("bob", 1), msg("bob", 1)], -            [msg("bob", 1), msg("alice", 2)] +            [make_msg("bob", 0)], +            [make_msg("bob", 2)], +            [make_msg("bob", 1), make_msg("bob", 1)], +            [make_msg("bob", 1), make_msg("alice", 2)],          ) -        for recent_messages in cases: -            last_message = recent_messages[0] - -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                config=self.config -            ): -                self.assertIsNone( -                    await mentions.apply(last_message, recent_messages, self.config) -                ) +        await self.run_allowed(cases)      @async_test      async def test_mentions_exceeding_limit(self):          """Messages with a higher than allowed amount of mentions."""          cases = ( -            Case( -                [msg("bob", 3)], +            DisallowedCase( +                [make_msg("bob", 3)],                  ("bob",), -                3 +                3,              ), -            Case( -                [msg("alice", 2), msg("alice", 0), msg("alice", 1)], +            DisallowedCase( +                [make_msg("alice", 2), make_msg("alice", 0), make_msg("alice", 1)],                  ("alice",), -                3 +                3,              ), -            Case( -                [msg("bob", 2), msg("alice", 3), msg("bob", 2)], +            DisallowedCase( +                [make_msg("bob", 2), make_msg("alice", 3), make_msg("bob", 2)],                  ("bob",), -                4 +                4,              )          ) -        for recent_messages, culprit, total_mentions in cases: -            last_message = recent_messages[0] -            relevant_messages = tuple( -                msg -                for msg in recent_messages -                if msg.author == last_message.author -            ) +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) -            with self.subTest( -                last_message=last_message, -                recent_messages=recent_messages, -                relevant_messages=relevant_messages, -                culprit=culprit, -                total_mentions=total_mentions, -                cofig=self.config -            ): -                desired_output = ( -                    f"sent {total_mentions} mentions in {self.config['interval']}s", -                    culprit, -                    relevant_messages -                ) -                self.assertTupleEqual( -                    await mentions.apply(last_message, recent_messages, self.config), -                    desired_output -                ) +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} mentions in {self.config['interval']}s" diff --git a/tests/bot/rules/test_newlines.py b/tests/bot/rules/test_newlines.py new file mode 100644 index 000000000..d61c4609d --- /dev/null +++ b/tests/bot/rules/test_newlines.py @@ -0,0 +1,105 @@ +from typing import Iterable, List + +from bot.rules import newlines +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, newline_groups: List[int]) -> MockMessage: +    """Init a MockMessage instance with `author` and content configured by `newline_groups". + +    Configure content by passing a list of ints, where each int `n` will generate +    a separate group of `n` newlines. + +    Example: +        newline_groups=[3, 1, 2] -> content="\n\n\n \n \n\n" +    """ +    content = " ".join("\n" * n for n in newline_groups) +    return MockMessage(author=author, content=content) + + +class TotalNewlinesRuleTests(RuleTest): +    """Tests the `newlines` antispam rule against allowed cases and total newline count violations.""" + +    def setUp(self): +        self.apply = newlines.apply +        self.config = { +            "max": 5,  # Max sum of newlines in relevant messages +            "max_consecutive": 3,  # Max newlines in one group, in one message +            "interval": 10, +        } + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases which do not violate the rule.""" +        cases = ( +            [make_msg("alice", [])],  # Single message with no newlines +            [make_msg("alice", [1, 2]), make_msg("alice", [1, 1])],  # 5 newlines in 2 messages +            [make_msg("alice", [2, 2, 1]), make_msg("bob", [2, 3])],  # 5 newlines from each author +            [make_msg("bob", [1]), make_msg("alice", [5])],  # Alice breaks the rule, but only bob is relevant +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_total(self): +        """Cases which violate the rule by having too many newlines in total.""" +        cases = ( +            DisallowedCase(  # Alice sends a total of 6 newlines (disallowed) +                [make_msg("alice", [2, 2]), make_msg("alice", [2])], +                ("alice",), +                6, +            ), +            DisallowedCase(  # Here we test that only alice's newlines count in the sum +                [make_msg("alice", [2, 2]), make_msg("bob", [3]), make_msg("alice", [3])], +                ("alice",), +                7, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_author = case.recent_messages[0].author +        return tuple(msg for msg in case.recent_messages if msg.author == last_author) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} newlines in {self.config['interval']}s" + + +class GroupNewlinesRuleTests(RuleTest): +    """ +    Tests the `newlines` antispam rule against max consecutive newline violations. + +    As these violations yield a different error report, they require a different +    `get_report` implementation. +    """ + +    def setUp(self): +        self.apply = newlines.apply +        self.config = {"max": 5, "max_consecutive": 3, "interval": 10} + +    @async_test +    async def test_disallows_messages_consecutive(self): +        """Cases which violate the rule due to having too many consecutive newlines.""" +        cases = ( +            DisallowedCase(  # Bob sends a group of newlines too large +                [make_msg("bob", [4])], +                ("bob",), +                4, +            ), +            DisallowedCase(  # Alice sends 5 in total (allowed), but 4 in one group (disallowed) +                [make_msg("alice", [1]), make_msg("alice", [4])], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_author = case.recent_messages[0].author +        return tuple(msg for msg in case.recent_messages if msg.author == last_author) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} consecutive newlines in {self.config['interval']}s" diff --git a/tests/bot/rules/test_role_mentions.py b/tests/bot/rules/test_role_mentions.py new file mode 100644 index 000000000..b339cccf7 --- /dev/null +++ b/tests/bot/rules/test_role_mentions.py @@ -0,0 +1,57 @@ +from typing import Iterable + +from bot.rules import role_mentions +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockMessage, async_test + + +def make_msg(author: str, n_mentions: int) -> MockMessage: +    """Build a MockMessage instance with `n_mentions` role mentions.""" +    return MockMessage(author=author, role_mentions=[None] * n_mentions) + + +class RoleMentionsRuleTests(RuleTest): +    """Tests for the `role_mentions` antispam rule.""" + +    def setUp(self): +        self.apply = role_mentions.apply +        self.config = {"max": 2, "interval": 10} + +    @async_test +    async def test_allows_messages_within_limit(self): +        """Cases with a total amount of role mentions within limit.""" +        cases = ( +            [make_msg("bob", 2)], +            [make_msg("bob", 1), make_msg("alice", 1), make_msg("bob", 1)], +        ) + +        await self.run_allowed(cases) + +    @async_test +    async def test_disallows_messages_beyond_limit(self): +        """Cases with more than the allowed amount of role mentions.""" +        cases = ( +            DisallowedCase( +                [make_msg("bob", 3)], +                ("bob",), +                3, +            ), +            DisallowedCase( +                [make_msg("alice", 2), make_msg("bob", 2), make_msg("alice", 2)], +                ("alice",), +                4, +            ), +        ) + +        await self.run_disallowed(cases) + +    def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: +        last_message = case.recent_messages[0] +        return tuple( +            msg +            for msg in case.recent_messages +            if msg.author == last_message.author +        ) + +    def get_report(self, case: DisallowedCase) -> str: +        return f"sent {case.n_violations} role mentions in {self.config['interval']}s" diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py index 5a88adc5c..bdfcc73e4 100644 --- a/tests/bot/test_api.py +++ b/tests/bot/test_api.py @@ -1,9 +1,7 @@ -import logging  import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock  from bot import api -from tests.base import LoggingTestCase  from tests.helpers import async_test @@ -34,7 +32,7 @@ class APIClientTests(unittest.TestCase):          self.assertEqual(error.response_text, "")          self.assertIs(error.response, self.error_api_response) -    def test_responde_code_error_string_representation_default_initialization(self): +    def test_response_code_error_string_representation_default_initialization(self):          """Test the string representation of `ResponseCodeError` initialized without text or json."""          error = api.ResponseCodeError(response=self.error_api_response)          self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") @@ -76,61 +74,3 @@ class APIClientTests(unittest.TestCase):              response_text=text_data          )          self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") - - -class LoggingHandlerTests(LoggingTestCase): -    """Tests the bot's API Log Handler.""" - -    @classmethod -    def setUpClass(cls): -        cls.debug_log_record = logging.LogRecord( -            name='my.logger', level=logging.DEBUG, -            pathname='my/logger.py', lineno=666, -            msg="Lemon wins", args=(), -            exc_info=None -        ) - -        cls.trace_log_record = logging.LogRecord( -            name='my.logger', level=logging.TRACE, -            pathname='my/logger.py', lineno=666, -            msg="This will not be logged", args=(), -            exc_info=None -        ) - -    def setUp(self): -        self.log_handler = api.APILoggingHandler(None) - -    def test_emit_appends_to_queue_with_stopped_event_loop(self): -        """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" -        with patch("bot.api.APILoggingHandler.ship_off") as ship_off: -            # Patch `ship_off` to ease testing against the return value of this coroutine. -            ship_off.return_value = 42 -            self.log_handler.emit(self.debug_log_record) - -        self.assertListEqual(self.log_handler.queue, [42]) - -    def test_emit_ignores_less_than_debug(self): -        """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" -        self.log_handler.emit(self.trace_log_record) -        self.assertListEqual(self.log_handler.queue, []) - -    def test_schedule_queued_tasks_for_empty_queue(self): -        """`APILoggingHandler` should not schedule anything when the queue is empty.""" -        with self.assertNotLogs(level=logging.DEBUG): -            self.log_handler.schedule_queued_tasks() - -    def test_schedule_queued_tasks_for_nonempty_queue(self): -        """`APILoggingHandler` should schedule logs when the queue is not empty.""" -        log = logging.getLogger("bot.api") - -        with self.assertLogs(logger=log, level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: -            self.log_handler.queue = [555] -            self.log_handler.schedule_queued_tasks() -            self.assertListEqual(self.log_handler.queue, []) -            create_task.assert_called_once_with(555) - -            [record] = logs.records -            self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") -            self.assertEqual(record.levelno, logging.DEBUG) -            self.assertEqual(record.name, 'bot.api') -            self.assertIn('via_handler', record.__dict__) @@ -3,7 +3,7 @@ max-line-length=120  docstring-convention=all  import-order-style=pycharm  application_import_names=bot,tests -exclude=.cache,.venv,constants.py +exclude=.cache,.venv,.git,constants.py  ignore=      B311,W503,E226,S311,T000      # Missing Docstrings @@ -15,5 +15,5 @@ ignore=      # Docstring Content      D400,D401,D402,D404,D405,D406,D407,D408,D409,D410,D411,D412,D413,D414,D416,D417      # Type Annotations -    TYP002,TYP003,TYP101,TYP102,TYP204,TYP206 -per-file-ignores=tests/*:D,TYP +    ANN002,ANN003,ANN101,ANN102,ANN204,ANN206 +per-file-ignores=tests/*:D,ANN  |