diff options
51 files changed, 2468 insertions, 1276 deletions
| diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..d572bd705 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +branch = true +source = +    bot +    tests diff --git a/.gitignore b/.gitignore index a191523b6..fb3156ab1 100644 --- a/.gitignore +++ b/.gitignore @@ -114,8 +114,8 @@ log.*  # Custom user configuration  config.yml -# JUnit XML reports from pytest -junit.xml +# xmlrunner unittest XML reports +TEST-**.xml  # Mac OS .DS_Store, which is a file that stores custom attributes of its containing folder  .DS_Store @@ -21,6 +21,7 @@ more_itertools = "~=7.2"  urllib3 = ">=1.24.2,<1.25"  [dev-packages] +coverage = "~=4.5"  flake8 = "~=3.7"  flake8-annotations = "~=1.1"  flake8-bugbear = "~=19.8" @@ -31,9 +32,8 @@ flake8-tidy-imports = "~=2.0"  flake8-todo = "~=0.7"  pre-commit = "~=1.18"  safety = "~=1.8" +unittest-xml-reporting = "~=2.5"  dodgy = "~=0.1" -pytest = "~=5.1" -pytest-cov = "~=2.7"  [requires]  python_version = "3.7" @@ -44,3 +44,5 @@ lint = "python -m flake8"  precommit = "pre-commit install"  build = "docker build -t pythondiscord/bot:latest -f Dockerfile ."  push = "docker push pythondiscord/bot:latest" +test = "coverage run -m unittest" +report = "coverage report" diff --git a/Pipfile.lock b/Pipfile.lock index 4e6b4eaf8..95955ff89 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@  {      "_meta": {          "hash": { -            "sha256": "c2537cc3c5b0886d0b38f9b48f4f4b93e1e74d925454aa71a2189bddedadde42" +            "sha256": "c27d699b4aeeed204dee41f924f682ae2a670add8549a8826e58776594370582"          },          "pipfile-spec": 6,          "requires": { @@ -18,11 +18,11 @@      "default": {          "aio-pika": {              "hashes": [ -                "sha256:29f27a8092169924c9eefb0c5e428d216706618dc9caa75ddb7759638e16cf26", -                "sha256:4f77ba9b6e7bc27fc88c49638bc3657ae5d4a2539e17fa0c2b25b370547b1b50" +                "sha256:1dcec3e3e3309e277511dc0d7d157676d0165c174a6a745673fc9cf0510db8f0", +                "sha256:dd5a23ca26a4872ee73bd107e4c545bace572cdec2a574aeb61f4062c7774b2a"              ],              "index": "pypi", -            "version": "==6.1.2" +            "version": "==6.1.3"          },          "aiodns": {              "hashes": [ @@ -83,10 +83,10 @@          },          "attrs": {              "hashes": [ -                "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79", -                "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399" +                "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", +                "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"              ], -            "version": "==19.1.0" +            "version": "==19.2.0"          },          "babel": {              "hashes": [ @@ -97,11 +97,11 @@          },          "beautifulsoup4": {              "hashes": [ -                "sha256:05668158c7b85b791c5abde53e50265e16f98ad601c402ba44d70f96c4159612", -                "sha256:25288c9e176f354bf277c0a10aa96c782a6a18a17122dba2e8cec4a97e03343b", -                "sha256:f040590be10520f2ea4c2ae8c3dae441c7cfff5308ec9d58a0ec0c1b8f81d469" +                "sha256:5279c36b4b2ec2cb4298d723791467e3000e5384a43ea0cdf5d45207c7e97169", +                "sha256:6135db2ba678168c07950f9a16c4031822c6f4aec75a65e0a97bc5ca09789931", +                "sha256:dcdef580e18a76d54002088602eba453eec38ebbcafafeaabd8cab12b6155d57"              ], -            "version": "==4.8.0" +            "version": "==4.8.1"          },          "certifi": {              "hashes": [ @@ -150,13 +150,6 @@              ],              "version": "==3.0.4"          }, -        "colorama": { -            "hashes": [ -                "sha256:05eed71e2e327246ad6b38c540c4a3117230b19679b875190486ddd2d721422d", -                "sha256:f8ac84de7840f5b9c4e3347b3c1eaa50f7e49c2b07596221daec5edaabbd7c48" -            ], -            "version": "==0.4.1" -        },          "deepdiff": {              "hashes": [                  "sha256:1123762580af0904621136d117c8397392a244d3ff0fa0a50de57a7939582476", @@ -204,10 +197,10 @@          },          "jinja2": {              "hashes": [ -                "sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013", -                "sha256:14dd6caf1527abb21f08f86c784eac40853ba93edb79552aa1e4b8aef1b61c7b" +                "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", +                "sha256:9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de"              ], -            "version": "==2.10.1" +            "version": "==2.10.3"          },          "jsonpickle": {              "hashes": [ @@ -407,10 +400,10 @@          },          "pytz": {              "hashes": [ -                "sha256:26c0b32e437e54a18161324a2fca3c4b9846b74a8dccddd843113109e1116b32", -                "sha256:c894d57500a4cd2d5c71114aaab77dbab5eabd9022308ce5ac9bb93a60a6f0c7" +                "sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d", +                "sha256:b02c06db6cf09c12dd25137e563b31700d3b80fcc4ad23abb7a315f2789819be"              ], -            "version": "==2019.2" +            "version": "==2019.3"          },          "pyyaml": {              "hashes": [ @@ -448,9 +441,10 @@          },          "snowballstemmer": {              "hashes": [ -                "sha256:713e53b79cbcf97bc5245a06080a33d54a77e7cce2f789c835a143bcdb5c033e" +                "sha256:209f257d7533fdb3cb73bdbd24f436239ca3b2fa67d56f6ff88e86be08cc5ef0", +                "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"              ], -            "version": "==1.9.1" +            "version": "==2.0.0"          },          "soupsieve": {              "hashes": [ @@ -568,19 +562,12 @@              ],              "version": "==1.3.0"          }, -        "atomicwrites": { -            "hashes": [ -                "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", -                "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6" -            ], -            "version": "==1.3.0" -        },          "attrs": {              "hashes": [ -                "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79", -                "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399" +                "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", +                "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"              ], -            "version": "==19.1.0" +            "version": "==19.2.0"          },          "certifi": {              "hashes": [ @@ -610,13 +597,6 @@              ],              "version": "==7.0"          }, -        "colorama": { -            "hashes": [ -                "sha256:05eed71e2e327246ad6b38c540c4a3117230b19679b875190486ddd2d721422d", -                "sha256:f8ac84de7840f5b9c4e3347b3c1eaa50f7e49c2b07596221daec5edaabbd7c48" -            ], -            "version": "==0.4.1" -        },          "coverage": {              "hashes": [                  "sha256:08907593569fe59baca0bf152c43f3863201efb6113ecb38ce7e97ce339805a6", @@ -652,6 +632,7 @@                  "sha256:fa964bae817babece5aa2e8c1af841bebb6d0b9add8e637548809d040443fee0",                  "sha256:ff37757e068ae606659c28c3bd0d923f9d29a85de79bf25b2b34b148473b5025"              ], +            "index": "pypi",              "version": "==4.5.4"          },          "dodgy": { @@ -701,11 +682,11 @@          },          "flake8-docstrings": {              "hashes": [ -                "sha256:1666dd069c9c457ee57e80af3c1a6b37b00cc1801c6fde88e455131bb2e186cd", -                "sha256:9c0db5a79a1affd70fdf53b8765c8a26bf968e59e0252d7f2fc546b41c0cda06" +                "sha256:3d5a31c7ec6b7367ea6506a87ec293b94a0a46c0bce2bb4975b7f1d09b6f3717", +                "sha256:a256ba91bc52307bef1de59e2a009c3cf61c3d0952dbe035d6ff7208940c2edc"              ],              "index": "pypi", -            "version": "==1.4.0" +            "version": "==1.5.0"          },          "flake8-import-order": {              "hashes": [ @@ -757,7 +738,6 @@                  "sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26",                  "sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af"              ], -            "markers": "python_version < '3.8'",              "version": "==0.23"          },          "mccabe": { @@ -788,13 +768,6 @@              ],              "version": "==19.2"          }, -        "pluggy": { -            "hashes": [ -                "sha256:0db4b7601aae1d35b4a033282da476845aa19185c1e6964b25cf324b5e4ec3e6", -                "sha256:fa5fa1622fa6dd5c030e9cad086fa19ef6a0cf6d7a2d12318e10cb49d6d68f34" -            ], -            "version": "==0.13.0" -        },          "pre-commit": {              "hashes": [                  "sha256:1d3c0587bda7c4e537a46c27f2c84aa006acc18facf9970bf947df596ce91f3f", @@ -803,13 +776,6 @@              "index": "pypi",              "version": "==1.18.3"          }, -        "py": { -            "hashes": [ -                "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", -                "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" -            ], -            "version": "==1.8.0" -        },          "pycodestyle": {              "hashes": [                  "sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56", @@ -838,22 +804,6 @@              ],              "version": "==2.4.2"          }, -        "pytest": { -            "hashes": [ -                "sha256:813b99704b22c7d377bbd756ebe56c35252bb710937b46f207100e843440b3c2", -                "sha256:cc6620b96bc667a0c8d4fa592a8c9c94178a1bd6cc799dbb057dfd9286d31a31" -            ], -            "index": "pypi", -            "version": "==5.1.3" -        }, -        "pytest-cov": { -            "hashes": [ -                "sha256:2b097cde81a302e1047331b48cadacf23577e431b61e9c6f49a1170bbe3d3da6", -                "sha256:e00ea4fdde970725482f1f35630d12f074e121a23801aabf2ae154ec6bdd343a" -            ], -            "index": "pypi", -            "version": "==2.7.1" -        },          "pyyaml": {              "hashes": [                  "sha256:0113bc0ec2ad727182326b61326afa3d1d8280ae1122493553fd6f4397f33df9", @@ -898,9 +848,10 @@          },          "snowballstemmer": {              "hashes": [ -                "sha256:713e53b79cbcf97bc5245a06080a33d54a77e7cce2f789c835a143bcdb5c033e" +                "sha256:209f257d7533fdb3cb73bdbd24f436239ca3b2fa67d56f6ff88e86be08cc5ef0", +                "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"              ], -            "version": "==1.9.1" +            "version": "==2.0.0"          },          "toml": {              "hashes": [ @@ -929,6 +880,14 @@              ],              "version": "==1.4.0"          }, +        "unittest-xml-reporting": { +            "hashes": [ +                "sha256:140982e4b58e4052d9ecb775525b246a96bfc1fc26097806e05ea06e9166dd6c", +                "sha256:d1fbc7a1b6c6680ccfe75b5e9701e5431c646970de049e687b4bb35ba4325d72" +            ], +            "index": "pypi", +            "version": "==2.5.1" +        },          "urllib3": {              "hashes": [                  "sha256:2393a695cd12afedd0dcb26fe5d50d0cf248e5a66f75dbd89a3d4eb333a61af4", @@ -944,13 +903,6 @@              ],              "version": "==16.7.5"          }, -        "wcwidth": { -            "hashes": [ -                "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", -                "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" -            ], -            "version": "==0.1.7" -        },          "zipp": {              "hashes": [                  "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c22bac089..da3b06201 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -30,9 +30,12 @@ jobs:        - script: python -m flake8          displayName: 'Run linter' -      - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz python -m pytest --junitxml=junit.xml --cov=bot --cov-branch --cov-report=term --cov-report=xml tests +      - script: BOT_API_KEY=foo BOT_TOKEN=bar WOLFRAM_API_KEY=baz coverage run -m xmlrunner          displayName: Run tests +      - script: coverage report -m && coverage xml -o coverage.xml +        displayName: Generate test coverage report +        - task: PublishCodeCoverageResults@1          displayName: 'Publish Coverage Results'          condition: succeededOrFailed() @@ -41,11 +44,11 @@ jobs:            summaryFileLocation: coverage.xml        - task: PublishTestResults@2 -        displayName: 'Publish Test Results'          condition: succeededOrFailed() +        displayName: 'Publish Test Results'          inputs: -          testResultsFiles: junit.xml -          testRunTitle: 'Bot Test results' +          testResultsFiles: '**/TEST-*.xml' +          testRunTitle: 'Bot Test Results'    - job: build      displayName: 'Build & Push Container' diff --git a/bot/__main__.py b/bot/__main__.py index 19a7e5ec6..f352cd60e 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -39,6 +39,7 @@ bot.load_extension("bot.cogs.logging")  bot.load_extension("bot.cogs.security")  # Commands, etc +bot.load_extension("bot.cogs.antimalware")  bot.load_extension("bot.cogs.antispam")  bot.load_extension("bot.cogs.bot")  bot.load_extension("bot.cogs.clean") diff --git a/bot/cogs/alias.py b/bot/cogs/alias.py index 6648805e9..5190c559b 100644 --- a/bot/cogs/alias.py +++ b/bot/cogs/alias.py @@ -79,10 +79,10 @@ class Alias (Cog):          """Alias for invoking <prefix>site faq."""          await self.invoke(ctx, "site faq") -    @command(name="rules", hidden=True) -    async def site_rules_alias(self, ctx: Context) -> None: +    @command(name="rules", aliases=("rule",), hidden=True) +    async def site_rules_alias(self, ctx: Context, *rules: int) -> None:          """Alias for invoking <prefix>site rules.""" -        await self.invoke(ctx, "site rules") +        await self.invoke(ctx, "site rules", *rules)      @command(name="reload", hidden=True)      async def extensions_reload_alias(self, ctx: Context, *extensions: Extension) -> None: diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py new file mode 100644 index 000000000..ababd6f18 --- /dev/null +++ b/bot/cogs/antimalware.py @@ -0,0 +1,56 @@ +import logging + +from discord import Message, NotFound +from discord.ext.commands import Bot, Cog + +from bot.constants import AntiMalware as AntiMalwareConfig, Channels + +log = logging.getLogger(__name__) + + +class AntiMalware(Cog): +    """Delete messages which contain attachments with non-whitelisted file extensions.""" + +    def __init__(self, bot: Bot): +        self.bot = bot + +    @Cog.listener() +    async def on_message(self, message: Message) -> None: +        """Identify messages with prohibited attachments.""" +        rejected_attachments = False +        detected_pyfile = False +        for attachment in message.attachments: +            if attachment.filename.lower().endswith('.py'): +                detected_pyfile = True +                break  # Other detections irrelevant because we prioritize the .py message. +            if not attachment.filename.lower().endswith(tuple(AntiMalwareConfig.whitelist)): +                rejected_attachments = True + +        if detected_pyfile or rejected_attachments: +            # Send a message to the user indicating the problem (with special treatment for .py) +            author = message.author +            if detected_pyfile: +                msg = ( +                    f"{author.mention}, it looks like you tried to attach a Python file - please " +                    f"use a code-pasting service such as https://paste.pythondiscord.com/ instead." +                ) +            else: +                meta_channel = self.bot.get_channel(Channels.meta) +                msg = ( +                    f"{author.mention}, it looks like you tried to attach a file type we don't " +                    f"allow. Feel free to ask in {meta_channel.mention} if you think this is a mistake." +                ) + +            await message.channel.send(msg) + +            # Delete the offending message: +            try: +                await message.delete() +            except NotFound: +                log.info(f"Tried to delete message `{message.id}`, but message could not be found.") + + +def setup(bot: Bot) -> None: +    """Antimalware cog load.""" +    bot.add_cog(AntiMalware(bot)) +    log.info("Cog loaded: AntiMalware") diff --git a/bot/cogs/moderation/infractions.py b/bot/cogs/moderation/infractions.py index 592ead60f..f2ae7b95d 100644 --- a/bot/cogs/moderation/infractions.py +++ b/bot/cogs/moderation/infractions.py @@ -2,6 +2,7 @@ import logging  import textwrap  import typing as t  from datetime import datetime +from gettext import ngettext  import dateutil.parser  import discord @@ -436,7 +437,13 @@ class Infractions(Scheduler, commands.Cog):          # Default values for the confirmation message and mod log.          confirm_msg = f":ok_hand: applied" -        expiry_msg = f" until {expiry}" if expiry else " permanently" + +        # Specifying an expiry for a note or warning makes no sense. +        if infr_type in ("note", "warning"): +            expiry_msg = "" +        else: +            expiry_msg = f" until {expiry}" if expiry else " permanently" +          dm_result = ""          dm_log_text = ""          expiry_log_text = f"Expires: {expiry}" if expiry else "" @@ -463,7 +470,8 @@ class Infractions(Scheduler, commands.Cog):                  "bot/infractions",                  params={"user__id": str(user.id)}              ) -            end_msg = f" ({len(infractions)} infractions total)" +            total = len(infractions) +            end_msg = f" ({total} infraction{ngettext('', 's', total)} total)"          # Execute the necessary actions to apply the infraction on Discord.          if action_coro: diff --git a/bot/cogs/site.py b/bot/cogs/site.py index c3bdf85e4..d95359159 100644 --- a/bot/cogs/site.py +++ b/bot/cogs/site.py @@ -126,15 +126,15 @@ class Site(Cog):          invalid_indices = tuple(              pick              for pick in rules -            if pick < 0 or pick >= len(full_rules) +            if pick < 1 or pick > len(full_rules)          )          if invalid_indices:              indices = ', '.join(map(str, invalid_indices)) -            await ctx.send(f":x: Invalid rule indices {indices}") +            await ctx.send(f":x: Invalid rule indices: {indices}")              return -        final_rules = tuple(f"**{pick}.** {full_rules[pick]}" for pick in rules) +        final_rules = tuple(f"**{pick}.** {full_rules[pick - 1]}" for pick in rules)          await LinePaginator.paginate(final_rules, ctx, rules_embed, max_lines=3) diff --git a/bot/constants.py b/bot/constants.py index f4f45eb2c..4beae84e9 100644 --- a/bot/constants.py +++ b/bot/constants.py @@ -345,6 +345,7 @@ class Channels(metaclass=YAMLGetter):      help_7: int      helpers: int      message_log: int +    meta: int      mod_alerts: int      modlog: int      off_topic_0: int @@ -460,6 +461,12 @@ class AntiSpam(metaclass=YAMLGetter):      rules: Dict[str, Dict[str, int]] +class AntiMalware(metaclass=YAMLGetter): +    section = "anti_malware" + +    whitelist: list + +  class BigBrother(metaclass=YAMLGetter):      section = 'big_brother' diff --git a/config-default.yml b/config-default.yml index ca405337e..197743296 100644 --- a/config-default.yml +++ b/config-default.yml @@ -107,6 +107,7 @@ guild:          help_7:                           587375768556797982          helpers:                          385474242440986624          message_log:       &MESSAGE_LOG   467752170159079424 +        meta:                             429409067623251969          mod_alerts:                       473092532147060736          modlog:            &MODLOG        282638479504965634          off_topic_0:                      291284109232308226 @@ -322,6 +323,27 @@ anti_spam:              max: 3 +anti_malware: +    whitelist: +        - '.3gp' +        - '.3g2' +        - '.avi' +        - '.bmp' +        - '.gif' +        - '.h264' +        - '.jpg' +        - '.jpeg' +        - '.m4v' +        - '.mkv' +        - '.mov' +        - '.mp4' +        - '.mpeg' +        - '.mpg' +        - '.png' +        - '.tiff' +        - '.wmv' + +  reddit:      request_delay: 60      subreddits: diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..6ab9bc93e --- /dev/null +++ b/tests/README.md @@ -0,0 +1,213 @@ +# Testing our Bot + +Our bot is one of the most important tools we have for running our community. As we don't want that tool break, we decided that we wanted to write unit tests for it. We hope that in the future, we'll have a 100% test coverage for the bot. This guide will help you get started with writing the tests needed to achieve that. + +_**Note:** This is a practical guide to getting started with writing tests for our bot, not a general introduction to writing unit tests in Python. If you're looking for a more general introduction, you may like Corey Schafer's [Python Tutorial: Unit Testing Your Code with the unittest Module](https://www.youtube.com/watch?v=6tNS--WetLI) or Ned Batchelder's PyCon talk [Getting Started Testing](https://www.youtube.com/watch?v=FxSsnHeWQBY)._ + +## Tools + +We are using the following modules and packages for our unit tests: + +- [unittest](https://docs.python.org/3/library/unittest.html) (standard library) +- [unittest.mock](https://docs.python.org/3/library/unittest.mock.html) (standard library) +- [coverage.py](https://coverage.readthedocs.io/en/stable/) + +To ensure the results you obtain on your personal machine are comparable to those generated in the Azure pipeline, please make sure to run your tests with the virtual environment defined by our [Pipfile](/Pipfile). To run your tests with `pipenv`, we've provided two "scripts" shortcuts: + +- `pipenv run test` will run `unittest` with `coverage.py` +- `pipenv run report` will generate a coverage report of the tests you've run with `pipenv run test`. If you append the `-m` flag to this command, the report will include the lines and branches not covered by tests in addition to the test coverage report. + +If you want a coverage report, make sure to run the tests with `pipenv run test` *first*. + +## Writing tests + +Since consistency is an important consideration for collaborative projects, we have written some guidelines on writing tests for the bot. In addition to these guidelines, it's a good idea to look at the existing code base for examples (e.g., [`test_converters.py`](/tests/bot/test_converters.py)). + +### File and directory structure + +To organize our test suite, we have chosen to mirror the directory structure of [`bot`](/bot/) in the [`tests`](/tests/) subdirectory. This makes it easy to find the relevant tests by providing a natural grouping of files. More general testing files, such as [`helpers.py`](/tests/helpers.py) are located directly in the `tests` subdirectory. + +All files containing tests should have a filename starting with `test_` to make sure `unittest` will discover them. This prefix is typically followed by the name of the file the tests are written for. If needed, a test file can contain multiple test classes, both to provide structure and to be able to provide different fixtures/set-up methods for different groups of tests. + +### Writing independent tests + +When writing unit tests, it's really important to make sure that each test that you write runs independently from all of the other tests. This both means that the code you write for one test shouldn't influence the result of another test and that if one tests fails, the other tests should still run. + +The basis for this is that when you write a test method, it should really only test a single aspect of the thing you're testing. This often means that you do not write one large test that tests "everything" that can be tested for a function, but rather that you write multiple smaller tests that each test a specific branch/path/condition of the function under scrutiny. + +To make sure you're not repeating the same set-up steps in all these smaller tests, `unittest` provides fixtures that are executed before and after each test is run. In addition to test fixtures, it also provides special set-up and clean-up methods that are run before the first test in a test class or after the last test of that class has been run. For more information, see the documentation for [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html#unittest.TestCase). + +#### Method names and docstrings + +As you can probably imagine, writing smaller, independent tests also results in a large number of tests. To make sure that it's easy to see which test does what, it is incredibly important to use good method names to identify what each test is doing. A general guideline is that the name should capture the goal of your test: What is this test method trying to assert? + +In addition to good method names, it's also really important to write a good *single-line* docstring. The `unittest` module will print such a single-line docstring along with the method name in the output it gives when a test fails. This means that a good docstring that really captures the purpose of the test makes it much easier to quickly make sense of output. + +#### Using self.subTest for independent subtests + +Another thing that you will probably encounter is that you want to test a function against a list of input and output values. Given the section on writing independent tests, you may now be tempted to copy-paste the same test method over and over again, once for each unique value that you want to test. However, that would result in a lot of duplicate code that is hard to maintain. + +Luckily, `unittest` provides a good alternative to that: the [`subTest`](https://docs.python.org/3/library/unittest.html#distinguishing-test-iterations-using-subtests) context manager. This method is often used in conjunction with a `for`-loop iterating of a collection of values that we want to test a function against and it provides two important features. First, it will make sure that if an assertion statements fails on one of the iterations, the other iterations are still run. The other important feature it provides is that it will distinguish the iterations from each other in the output. + +This is an example of `TestCase.subTest` in action (taken from [`test_converters.py`](/tests/bot/test_converters.py)): + +```py +    def test_tag_content_converter_for_valid(self): +        """TagContentConverter should return correct values for valid input.""" +        test_values = ( +            ('hello', 'hellpo'), +            ('  h ello  ', 'h ello'), +        ) + +        for content, expected_conversion in test_values: +            with self.subTest(content=content, expected_conversion=expected_conversion): +                conversion = asyncio.run(TagContentConverter.convert(self.context, content)) +                self.assertEqual(conversion, expected_conversion) +``` + +It's important to note the keyword arguments we provide to the `self.subTest` context manager: These keyword arguments and their values will printed in the output when one of the subtests fail, making sure we know *which* subTest failed: + +``` +.................................................................... +====================================================================== +FAIL: test_tag_content_converter_for_valid (tests.bot.test_converters.ConverterTests) (content='hello', expected_conversion='hellpo') +TagContentConverter should return correct values for valid input. +---------------------------------------------------------------------- + +# ... +``` + +## Mocking + +As we are trying to test our "units" of code independently, we want to make sure that we do not rely objects and data generated by "external" code. If we we did, then we wouldn't know if the failure we're observing was caused by the code we are actually trying to test or something external to it. + + +However, the features that we are trying to test often depend on those objects generated by external pieces of code. It would be difficult to test a bot command without having access to a `Context` instance. Fortunately, there's a solution for that: we use fake objects that act like the true object. We call these fake objects "mocks".  + +To create these mock object, we mainly use the [`unittest.mock`](https://docs.python.org/3/library/unittest.mock.html) module. In addition, we have also defined a couple of specialized mock objects that mock specific `discord.py` types (see the section on the below.). + +An example of mocking is when we provide a command with a mocked version of `discord.ext.commands.Context` object instead of a real `Context` object. This makes sure we can then check (_assert_) if the `send` method of the mocked Context object was called with the correct message content (without having to send a real message to the Discord API!): + +```py +import asyncio +import unittest + +from bot.cogs import bot +from tests.helpers import MockBot, MockContext + + +class BotCogTests(unittest.TestCase): +    def test_echo_command_correctly_echoes_arguments(self): +        """Test if the `!echo <text>` command correctly echoes the content.""" +        mocked_bot = MockBot() +        bot_cog = bot.Bot(mocked_bot) + +        mocked_context = MockContext() + +        text = "Hello! This should be echoed!" + +        asyncio.run(bot_cog.echo_command.callback(bot_cog, mocked_context, text=text)) + +        mocked_context.send.assert_called_with(text) +``` + +### Mocking coroutines + +By default, the `unittest.mock.Mock` and `unittest.mock.MagicMock` classes cannot mock coroutines, since the `__call__` method they provide is synchronous. In anticipation of the `AsyncMock` that will be [introduced in Python 3.8](https://docs.python.org/3.9/whatsnew/3.8.html#unittest), we have added an `AsyncMock` helper to [`helpers.py`](/tests/helpers.py). Do note that this drop-in replacement only implements an asynchronous `__call__` method, not the additional assertions that will come with the new `AsyncMock` type in Python 3.8.  + +### Special mocks for some `discord.py` types + +To quote Ned Batchelder, Mock objects are "automatic chameleons". This means that they will happily allow the access to any attribute or method and provide a mocked value in return. One downside to this is that if the code you are testing gets the name of the attribute wrong, your mock object will not complain and the test may still pass. + +In order to avoid that, we have defined a number of Mock types in [`helpers.py`](/tests/helpers.py) that follow the specifications of the actual Discord types they are mocking. This means that trying to access an attribute or method on a mocked object that does not exist on the equivalent `discord.py` object will result in an `AttributeError`. In addition, these mocks have some sensible defaults and **pass `isinstance` checks for the types they are mocking**.  + +These special mocks are added when they are needed, so if you think it would be sensible to add another one, feel free to propose one in your PR. + +**Note:** These mock types only "know" the attributes that are set by default when these `discord.py` types are first initialized. If you need to work with dynamically set attributes that are added after initialization, you can still explicitly mock them: + +```py +import unittest.mock +from tests.helpers import MockGuild + +guild = MockGuild() +guild.some_attribute = unittest.mock.MagicMock() +``` + +The attribute `some_attribute` will now be accessible as a `MagicMock` on the mocked object. + +--- + +## Some considerations + +Finally, there are some considerations to make when writing tests, both for writing tests in general and for writing tests for our bot in particular. + +### Test coverage is a starting point + +Having test coverage is a good starting point for unit testing: If a part of your code was not covered by a test, we know that we have not tested it properly. The reverse is unfortunately not true: Even if the code we are testing has 100% branch coverage, it does not mean it's fully tested or guaranteed to work.  + +One problem is that 100% branch coverage may be misleading if we haven't tested our code against all the realistic input it may get in production. For instance, take a look at the following `member_information` function and the test we've written for it: + +```py +import datetime +import unittest +import unittest.mock + + +def member_information(member): +    joined = member.joined.stfptime("%d-%m-%Y") if member.joined else "unknown" +    return f"{member.name} (joined: {joined})" + + +class FunctionsTests(unittest.TestCase): +    def test_member_information(self): +        member = unittest.mock.Mock() +        member.name = "lemon" +        member.joined = None +        self.assertEqual(member_information(member), "lemon (joined: unknown)") +``` + +If you were to run this test, not only would the function pass the test, `coverage.py` will also tell us that the test provides 100% branch coverage for the function. Can you spot the bug the test suite did not catch? + +The problem here is that we have only tested our function with a member object that had `None` for the `member.joined` attribute. This means that `member.joined.stfptime("%d-%m-%Y")` was never executed during our test, leading to us missing the spelling mistake in `stfptime` (it should be `strftime`).  + +Adding another test would not increase the test coverage we have, but it does ensure that we'll notice that this function can fail with realistic data: + +```py +# (...) +class FunctionsTests(unittest.TestCase): +    # (...) +    def test_member_information_with_join_datetime(self): +        member = unittest.mock.Mock() +        member.name = "lemon" +        member.joined = datetime.datetime(year=2019, month=10, day=10) +        self.assertEqual(member_information(member), "lemon (joined: 10-10-2019)") +``` + +Output: +``` +.E +====================================================================== +ERROR: test_member_information_with_join_datetime (tests.test_functions.FunctionsTests) +---------------------------------------------------------------------- +Traceback (most recent call last): +  File "/home/pydis/playground/tests/test_functions.py", line 23, in test_member_information_with_join_datetime +    self.assertEqual(member_information(member), "lemon (joined: 10-10-2019)") +  File "/home/pydis/playground/tests/test_functions.py", line 8, in member_information +    joined = member.joined.stfptime("%d-%m-%Y") if member.joined else "unknown" +AttributeError: 'datetime.datetime' object has no attribute 'stfptime' + +---------------------------------------------------------------------- +Ran 2 tests in 0.003s + +FAILED (errors=1) +``` + +What's more, even if the spelling mistake would not have been there, the first test did not test if the `member_information` function formatted the `member.join` according to the output we actually want to see. + +All in all, it's not only important to consider if all statements or branches were touched at least once with a test, but also if they are extensively tested in all situations that may happen in production. + +### Unit Testing vs Integration Testing + +Another restriction of unit testing is that it tests, well, in units. Even if we can guarantee that the units work as they should independently, we have no guarantee that they will actually work well together. Even more, while the mocking described above gives us a lot of flexibility in factoring out external code, we are work under the implicit assumption that we fully understand those external parts and utilize it correctly. What if our mocked `Context` object works with a `send` method, but `discord.py` has changed it to a `send_message` method in a recent update? It could mean our tests are passing, but the code it's testing still doesn't work in production. + +The answer to this is that we also need to make sure that the individual parts come together into a working application. In addition, we will also need to make sure that the application communicates correctly with external applications. Since we currently have no automated integration tests or functional tests, that means **it's still very important to fire up the bot and test the code you've written manually** in addition to the unit tests you've written. diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..2228110ad 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import logging + + +log = logging.getLogger() +log.setLevel(logging.CRITICAL) diff --git a/tests/base.py b/tests/base.py new file mode 100644 index 000000000..029a249ed --- /dev/null +++ b/tests/base.py @@ -0,0 +1,67 @@ +import logging +import unittest +from contextlib import contextmanager + + +class _CaptureLogHandler(logging.Handler): +    """ +    A logging handler capturing all (raw and formatted) logging output. +    """ + +    def __init__(self): +        super().__init__() +        self.records = [] + +    def emit(self, record): +        self.records.append(record) + + +class LoggingTestCase(unittest.TestCase): +    """TestCase subclass that adds more logging assertion tools.""" + +    @contextmanager +    def assertNotLogs(self, logger=None, level=None, msg=None): +        """ +        Asserts that no logs of `level` and higher were emitted by `logger`. + +        You can specify a specific `logger`, the minimum `logging` level we want to watch and a +        custom `msg` to be added to the `AssertionError` if thrown. If the assertion fails, the +        recorded log records will be outputted with the `AssertionError` message. The context +        manager does not yield a live `look` into the logging records, since we use this context +        manager when we're testing under the assumption that no log records will be emitted. +        """ +        if not isinstance(logger, logging.Logger): +            logger = logging.getLogger(logger) + +        if level: +            level = logging._nameToLevel.get(level, level) +        else: +            level = logging.INFO + +        handler = _CaptureLogHandler() +        old_handlers = logger.handlers[:] +        old_level = logger.level +        old_propagate = logger.propagate + +        logger.handlers = [handler] +        logger.setLevel(level) +        logger.propagate = False + +        try: +            yield +        except Exception as exc: +            raise exc +        finally: +            logger.handlers = old_handlers +            logger.propagate = old_propagate +            logger.setLevel(old_level) + +        if handler.records: +            level_name = logging.getLevelName(level) +            n_logs = len(handler.records) +            base_message = f"{n_logs} logs of {level_name} or higher were triggered on {logger.name}:\n" +            records = [str(record) for record in handler.records] +            record_message = "\n".join(records) +            standard_message = self._truncateMessage(base_message, record_message) +            msg = self._formatMessage(msg, standard_message) +            self.fail(msg) diff --git a/tests/cogs/__init__.py b/tests/bot/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/cogs/__init__.py +++ b/tests/bot/__init__.py diff --git a/tests/cogs/sync/__init__.py b/tests/bot/cogs/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/cogs/sync/__init__.py +++ b/tests/bot/cogs/__init__.py diff --git a/tests/rules/__init__.py b/tests/bot/cogs/sync/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/rules/__init__.py +++ b/tests/bot/cogs/sync/__init__.py diff --git a/tests/bot/cogs/sync/test_roles.py b/tests/bot/cogs/sync/test_roles.py new file mode 100644 index 000000000..27ae27639 --- /dev/null +++ b/tests/bot/cogs/sync/test_roles.py @@ -0,0 +1,126 @@ +import unittest + +from bot.cogs.sync.syncers import Role, get_roles_for_sync + + +class GetRolesForSyncTests(unittest.TestCase): +    """Tests constructing the roles to synchronize with the site.""" + +    def test_get_roles_for_sync_empty_return_for_equal_roles(self): +        """No roles should be synced when no diff is found.""" +        api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} +        guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            (set(), set(), set()) +        ) + +    def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(self): +        """Roles to be synced are returned when non-ID attributes differ.""" +        api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} +        guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            (set(), guild_roles, set()) +        ) + +    def test_get_roles_only_returns_roles_that_require_update(self): +        """Roles that require an update should be returned as the second tuple element.""" +        api_roles = { +            Role(id=41, name='old name', colour=33, permissions=0x8, position=1), +            Role(id=53, name='other role', colour=55, permissions=0, position=3) +        } +        guild_roles = { +            Role(id=41, name='new name', colour=35, permissions=0x8, position=2), +            Role(id=53, name='other role', colour=55, permissions=0, position=3) +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                set(), +                {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, +                set(), +            ) +        ) + +    def test_get_roles_returns_new_roles_in_first_tuple_element(self): +        """Newly created roles are returned as the first tuple element.""" +        api_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +        } +        guild_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +            Role(id=53, name='other role', colour=55, permissions=0, position=2) +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, +                set(), +                set(), +            ) +        ) + +    def test_get_roles_returns_roles_to_update_and_new_roles(self): +        """Newly created and updated roles should be returned together.""" +        api_roles = { +            Role(id=41, name='old name', colour=35, permissions=0x8, position=1), +        } +        guild_roles = { +            Role(id=41, name='new name', colour=40, permissions=0x16, position=2), +            Role(id=53, name='other role', colour=55, permissions=0, position=3) +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, +                {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, +                set(), +            ) +        ) + +    def test_get_roles_returns_roles_to_delete(self): +        """Roles to be deleted should be returned as the third tuple element.""" +        api_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), +        } +        guild_roles = { +            Role(id=41, name='name', colour=35, permissions=0x8, position=1), +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                set(), +                set(), +                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, +            ) +        ) + +    def test_get_roles_returns_roles_to_delete_update_and_new_roles(self): +        """When roles were added, updated, and removed, all of them are returned properly.""" +        api_roles = { +            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), +            Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), +            Role(id=71, name='to update', colour=99, permissions=0x9, position=3), +        } +        guild_roles = { +            Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), +            Role(id=81, name='to create', colour=99, permissions=0x9, position=4), +            Role(id=71, name='updated', colour=101, permissions=0x5, position=3), +        } + +        self.assertEqual( +            get_roles_for_sync(guild_roles, api_roles), +            ( +                {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, +                {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, +                {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, +            ) +        ) diff --git a/tests/bot/cogs/sync/test_users.py b/tests/bot/cogs/sync/test_users.py new file mode 100644 index 000000000..ccaf67490 --- /dev/null +++ b/tests/bot/cogs/sync/test_users.py @@ -0,0 +1,84 @@ +import unittest + +from bot.cogs.sync.syncers import User, get_users_for_sync + + +def fake_user(**kwargs): +    kwargs.setdefault('id', 43) +    kwargs.setdefault('name', 'bob the test man') +    kwargs.setdefault('discriminator', 1337) +    kwargs.setdefault('avatar_hash', None) +    kwargs.setdefault('roles', (666,)) +    kwargs.setdefault('in_guild', True) +    return User(**kwargs) + + +class GetUsersForSyncTests(unittest.TestCase): +    """Tests constructing the users to synchronize with the site.""" + +    def test_get_users_for_sync_returns_nothing_for_empty_params(self): +        """When no users are given, none are returned.""" +        self.assertEqual( +            get_users_for_sync({}, {}), +            (set(), set()) +        ) + +    def test_get_users_for_sync_returns_nothing_for_equal_users(self): +        """When no users are updated, none are returned.""" +        api_users = {43: fake_user()} +        guild_users = {43: fake_user()} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), set()) +        ) + +    def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(self): +        """When a non-ID-field differs, the user to update is returned.""" +        api_users = {43: fake_user()} +        guild_users = {43: fake_user(name='new fancy name')} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), {fake_user(name='new fancy name')}) +        ) + +    def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(self): +        """When new users join the guild, they are returned as the first tuple element.""" +        api_users = {43: fake_user()} +        guild_users = {43: fake_user(), 63: fake_user(id=63)} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            ({fake_user(id=63)}, set()) +        ) + +    def test_get_users_for_sync_updates_in_guild_field_on_user_leave(self): +        """When a user leaves the guild, the `in_guild` flag is updated to `False`.""" +        api_users = {43: fake_user(), 63: fake_user(id=63)} +        guild_users = {43: fake_user()} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), {fake_user(id=63, in_guild=False)}) +        ) + +    def test_get_users_for_sync_updates_and_creates_users_as_needed(self): +        """When one user left and another one was updated, both are returned.""" +        api_users = {43: fake_user()} +        guild_users = {63: fake_user(id=63)} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            ({fake_user(id=63)}, {fake_user(in_guild=False)}) +        ) + +    def test_get_users_for_sync_does_not_duplicate_update_users(self): +        """When the API knows a user the guild doesn't, nothing is performed.""" +        api_users = {43: fake_user(in_guild=False)} +        guild_users = {} + +        self.assertEqual( +            get_users_for_sync(guild_users, api_users), +            (set(), set()) +        ) diff --git a/tests/bot/cogs/test_antispam.py b/tests/bot/cogs/test_antispam.py new file mode 100644 index 000000000..ce5472c71 --- /dev/null +++ b/tests/bot/cogs/test_antispam.py @@ -0,0 +1,35 @@ +import unittest + +from bot.cogs import antispam + + +class AntispamConfigurationValidationTests(unittest.TestCase): +    """Tests validation of the antispam cog configuration.""" + +    def test_default_antispam_config_is_valid(self): +        """The default antispam configuration is valid.""" +        validation_errors = antispam.validate_config() +        self.assertEqual(validation_errors, {}) + +    def test_unknown_rule_returns_error(self): +        """Configuring an unknown rule returns an error.""" +        self.assertEqual( +            antispam.validate_config({'invalid-rule': {}}), +            {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} +        ) + +    def test_missing_keys_returns_error(self): +        """Not configuring required keys returns an error.""" +        keys = (('interval', 'max'), ('max', 'interval')) +        for configured_key, unconfigured_key in keys: +            with self.subTest( +                configured_key=configured_key, +                unconfigured_key=unconfigured_key +            ): +                config = {'burst': {configured_key: 10}} +                error = f"Key `{unconfigured_key}` is required but not set for rule `burst`" + +                self.assertEqual( +                    antispam.validate_config(config), +                    {'burst': error} +                ) diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py new file mode 100644 index 000000000..9bbd35a91 --- /dev/null +++ b/tests/bot/cogs/test_information.py @@ -0,0 +1,164 @@ +import asyncio +import textwrap +import unittest +import unittest.mock + +import discord + +from bot import constants +from bot.cogs import information +from tests.helpers import AsyncMock, MockBot, MockContext, MockGuild, MockMember, MockRole + + +class InformationCogTests(unittest.TestCase): +    """Tests the Information cog.""" + +    @classmethod +    def setUpClass(cls): +        cls.moderator_role = MockRole(name="Moderator", role_id=constants.Roles.moderator) + +    def setUp(self): +        """Sets up fresh objects for each test.""" +        self.bot = MockBot() + +        self.cog = information.Information(self.bot) + +        self.ctx = MockContext() +        self.ctx.author.roles.append(self.moderator_role) + +    def test_roles_command_command(self): +        """Test if the `role_info` command correctly returns the `moderator_role`.""" +        self.ctx.guild.roles.append(self.moderator_role) + +        self.cog.roles_info.can_run = AsyncMock() +        self.cog.roles_info.can_run.return_value = True + +        coroutine = self.cog.roles_info.callback(self.cog, self.ctx) + +        self.assertIsNone(asyncio.run(coroutine)) +        self.ctx.send.assert_called_once() + +        _, kwargs = self.ctx.send.call_args +        embed = kwargs.pop('embed') + +        self.assertEqual(embed.title, "Role information") +        self.assertEqual(embed.colour, discord.Colour.blurple()) +        self.assertEqual(embed.description, f"`{self.moderator_role.id}` - {self.moderator_role.mention}\n") +        self.assertEqual(embed.footer.text, "Total roles: 1") + +    def test_role_info_command(self): +        """Tests the `role info` command.""" +        dummy_role = MockRole( +            name="Dummy", +            role_id=112233445566778899, +            colour=discord.Colour.blurple(), +            position=10, +            members=[self.ctx.author], +            permissions=discord.Permissions(0) +        ) + +        admin_role = MockRole( +            name="Admins", +            role_id=998877665544332211, +            colour=discord.Colour.red(), +            position=3, +            members=[self.ctx.author], +            permissions=discord.Permissions(0), +        ) + +        self.ctx.guild.roles.append([dummy_role, admin_role]) + +        self.cog.role_info.can_run = AsyncMock() +        self.cog.role_info.can_run.return_value = True + +        coroutine = self.cog.role_info.callback(self.cog, self.ctx, dummy_role, admin_role) + +        self.assertIsNone(asyncio.run(coroutine)) + +        self.assertEqual(self.ctx.send.call_count, 2) + +        (_, dummy_kwargs), (_, admin_kwargs) = self.ctx.send.call_args_list + +        dummy_embed = dummy_kwargs["embed"] +        admin_embed = admin_kwargs["embed"] + +        self.assertEqual(dummy_embed.title, "Dummy info") +        self.assertEqual(dummy_embed.colour, discord.Colour.blurple()) + +        self.assertEqual(dummy_embed.fields[0].value, str(dummy_role.id)) +        self.assertEqual(dummy_embed.fields[1].value, f"#{dummy_role.colour.value:0>6x}") +        self.assertEqual(dummy_embed.fields[2].value, "0.63 0.48 218") +        self.assertEqual(dummy_embed.fields[3].value, "1") +        self.assertEqual(dummy_embed.fields[4].value, "10") +        self.assertEqual(dummy_embed.fields[5].value, "0") + +        self.assertEqual(admin_embed.title, "Admins info") +        self.assertEqual(admin_embed.colour, discord.Colour.red()) + +    @unittest.mock.patch('bot.cogs.information.time_since') +    def test_server_info_command(self, time_since_patch): +        time_since_patch.return_value = '2 days ago' + +        self.ctx.guild = MockGuild( +            features=('lemons', 'apples'), +            region="The Moon", +            roles=[self.moderator_role], +            channels=[ +                discord.TextChannel( +                    state={}, +                    guild=self.ctx.guild, +                    data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} +                ), +                discord.CategoryChannel( +                    state={}, +                    guild=self.ctx.guild, +                    data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} +                ), +                discord.VoiceChannel( +                    state={}, +                    guild=self.ctx.guild, +                    data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} +                ) +            ], +            members=[ +                *(MockMember(status='online') for _ in range(2)), +                *(MockMember(status='idle') for _ in range(1)), +                *(MockMember(status='dnd') for _ in range(4)), +                *(MockMember(status='offline') for _ in range(3)), +            ], +            member_count=1_234, +            icon_url='a-lemon.jpg', +        ) + +        coroutine = self.cog.server_info.callback(self.cog, self.ctx) +        self.assertIsNone(asyncio.run(coroutine)) + +        time_since_patch.assert_called_once_with(self.ctx.guild.created_at, precision='days') +        _, kwargs = self.ctx.send.call_args +        embed = kwargs.pop('embed') +        self.assertEqual(embed.colour, discord.Colour.blurple()) +        self.assertEqual( +            embed.description, +            textwrap.dedent( +                f""" +                **Server information** +                Created: {time_since_patch.return_value} +                Voice region: {self.ctx.guild.region} +                Features: {', '.join(self.ctx.guild.features)} + +                **Counts** +                Members: {self.ctx.guild.member_count:,} +                Roles: {len(self.ctx.guild.roles)} +                Text: 1 +                Voice: 1 +                Channel categories: 1 + +                **Members** +                {constants.Emojis.status_online} 2 +                {constants.Emojis.status_idle} 1 +                {constants.Emojis.status_dnd} 4 +                {constants.Emojis.status_offline} 3 +                """ +            ) +        ) +        self.assertEqual(embed.thumbnail.url, 'a-lemon.jpg') diff --git a/tests/bot/cogs/test_security.py b/tests/bot/cogs/test_security.py new file mode 100644 index 000000000..efa7a50b1 --- /dev/null +++ b/tests/bot/cogs/test_security.py @@ -0,0 +1,59 @@ +import logging +import unittest +from unittest.mock import MagicMock + +from discord.ext.commands import NoPrivateMessage + +from bot.cogs import security +from tests.helpers import MockBot, MockContext + + +class SecurityCogTests(unittest.TestCase): +    """Tests the `Security` cog.""" + +    def setUp(self): +        """Attach an instance of the cog to the class for tests.""" +        self.bot = MockBot() +        self.cog = security.Security(self.bot) +        self.ctx = MockContext() + +    def test_check_additions(self): +        """The cog should add its checks after initialization.""" +        self.bot.check.assert_any_call(self.cog.check_on_guild) +        self.bot.check.assert_any_call(self.cog.check_not_bot) + +    def test_check_not_bot_returns_false_for_humans(self): +        """The bot check should return `True` when invoked with human authors.""" +        self.ctx.author.bot = False +        self.assertTrue(self.cog.check_not_bot(self.ctx)) + +    def test_check_not_bot_returns_true_for_robots(self): +        """The bot check should return `False` when invoked with robotic authors.""" +        self.ctx.author.bot = True +        self.assertFalse(self.cog.check_not_bot(self.ctx)) + +    def test_check_on_guild_raises_when_outside_of_guild(self): +        """When invoked outside of a guild, `check_on_guild` should cause an error.""" +        self.ctx.guild = None + +        with self.assertRaises(NoPrivateMessage, msg="This command cannot be used in private messages."): +            self.cog.check_on_guild(self.ctx) + +    def test_check_on_guild_returns_true_inside_of_guild(self): +        """When invoked inside of a guild, `check_on_guild` should return `True`.""" +        self.ctx.guild = "lemon's lemonade stand" +        self.assertTrue(self.cog.check_on_guild(self.ctx)) + + +class SecurityCogLoadTests(unittest.TestCase): +    """Tests loading the `Security` cog.""" + +    def test_security_cog_load(self): +        """Cog loading logs a message at `INFO` level.""" +        bot = MagicMock() +        with self.assertLogs(logger='bot.cogs.security', level=logging.INFO) as cm: +            security.setup(bot) +            bot.add_cog.assert_called_once() + +        [line] = cm.output +        self.assertIn("Cog loaded: Security", line) diff --git a/tests/bot/cogs/test_token_remover.py b/tests/bot/cogs/test_token_remover.py new file mode 100644 index 000000000..dfb1bafc9 --- /dev/null +++ b/tests/bot/cogs/test_token_remover.py @@ -0,0 +1,135 @@ +import asyncio +import logging +import unittest +from unittest.mock import MagicMock + +from discord import Colour + +from bot.cogs.token_remover import ( +    DELETION_MESSAGE_TEMPLATE, +    TokenRemover, +    setup as setup_cog, +) +from bot.constants import Channels, Colours, Event, Icons +from tests.helpers import AsyncMock, MockBot, MockMessage + + +class TokenRemoverTests(unittest.TestCase): +    """Tests the `TokenRemover` cog.""" + +    def setUp(self): +        """Adds the cog, a bot, and a message to the instance for usage in tests.""" +        self.bot = MockBot() +        self.bot.get_cog.return_value = MagicMock() +        self.bot.get_cog.return_value.send_log_message = AsyncMock() +        self.cog = TokenRemover(bot=self.bot) + +        self.msg = MockMessage(message_id=555, content='') +        self.msg.author.__str__ = MagicMock() +        self.msg.author.__str__.return_value = 'lemon' +        self.msg.author.bot = False +        self.msg.author.avatar_url_as.return_value = 'picture-lemon.png' +        self.msg.author.id = 42 +        self.msg.author.mention = '@lemon' +        self.msg.channel.mention = "#lemonade-stand" + +    def test_is_valid_user_id_is_true_for_numeric_content(self): +        """A string decoding to numeric characters is a valid user ID.""" +        # MTIz = base64(123) +        self.assertTrue(TokenRemover.is_valid_user_id('MTIz')) + +    def test_is_valid_user_id_is_false_for_alphabetic_content(self): +        """A string decoding to alphabetic characters is not a valid user ID.""" +        # YWJj = base64(abc) +        self.assertFalse(TokenRemover.is_valid_user_id('YWJj')) + +    def test_is_valid_timestamp_is_true_for_valid_timestamps(self): +        """A string decoding to a valid timestamp should be recognized as such.""" +        self.assertTrue(TokenRemover.is_valid_timestamp('DN9r_A')) + +    def test_is_valid_timestamp_is_false_for_invalid_values(self): +        """A string not decoding to a valid timestamp should not be recognized as such.""" +        # MTIz = base64(123) +        self.assertFalse(TokenRemover.is_valid_timestamp('MTIz')) + +    def test_mod_log_property(self): +        """The `mod_log` property should ask the bot to return the `ModLog` cog.""" +        self.bot.get_cog.return_value = 'lemon' +        self.assertEqual(self.cog.mod_log, self.bot.get_cog.return_value) +        self.bot.get_cog.assert_called_once_with('ModLog') + +    def test_ignores_bot_messages(self): +        """When the message event handler is called with a bot message, nothing is done.""" +        self.msg.author.bot = True +        coroutine = self.cog.on_message(self.msg) +        self.assertIsNone(asyncio.run(coroutine)) + +    def test_ignores_messages_without_tokens(self): +        """Messages without anything looking like a token are ignored.""" +        for content in ('', 'lemon wins'): +            with self.subTest(content=content): +                self.msg.content = content +                coroutine = self.cog.on_message(self.msg) +                self.assertIsNone(asyncio.run(coroutine)) + +    def test_ignores_messages_with_invalid_tokens(self): +        """Messages with values that are invalid tokens are ignored.""" +        for content in ('foo.bar.baz', 'x.y.'): +            with self.subTest(content=content): +                self.msg.content = content +                coroutine = self.cog.on_message(self.msg) +                self.assertIsNone(asyncio.run(coroutine)) + +    def test_censors_valid_tokens(self): +        """Valid tokens are censored.""" +        cases = ( +            # (content, censored_token) +            ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), +        ) + +        for content, censored_token in cases: +            with self.subTest(content=content, censored_token=censored_token): +                self.msg.content = content +                coroutine = self.cog.on_message(self.msg) +                with self.assertLogs(logger='bot.cogs.token_remover', level=logging.DEBUG) as cm: +                    self.assertIsNone(asyncio.run(coroutine))  # no return value + +                [line] = cm.output +                log_message = ( +                    "Censored a seemingly valid token sent by " +                    "lemon (`42`) in #lemonade-stand, " +                    f"token was `{censored_token}`" +                ) +                self.assertIn(log_message, line) + +                self.msg.delete.assert_called_once_with() +                self.msg.channel.send.assert_called_once_with( +                    DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') +                ) +                self.bot.get_cog.assert_called_with('ModLog') +                self.msg.author.avatar_url_as.assert_called_once_with(static_format='png') + +                mod_log = self.bot.get_cog.return_value +                mod_log.ignore.assert_called_once_with(Event.message_delete, self.msg.id) +                mod_log.send_log_message.assert_called_once_with( +                    icon_url=Icons.token_removed, +                    colour=Colour(Colours.soft_red), +                    title="Token removed!", +                    text=log_message, +                    thumbnail='picture-lemon.png', +                    channel_id=Channels.mod_alerts +                ) + + +class TokenRemoverSetupTests(unittest.TestCase): +    """Tests setup of the `TokenRemover` cog.""" + +    def test_setup(self): +        """Setup of the cog should log a message at `INFO` level.""" +        bot = MockBot() +        with self.assertLogs(logger='bot.cogs.token_remover', level=logging.INFO) as cm: +            setup_cog(bot) + +        [line] = cm.output +        bot.add_cog.assert_called_once() +        self.assertIn("Cog loaded: TokenRemover", line) diff --git a/tests/utils/__init__.py b/tests/bot/patches/__init__.py index e69de29bb..e69de29bb 100644 --- a/tests/utils/__init__.py +++ b/tests/bot/patches/__init__.py diff --git a/tests/bot/resources/__init__.py b/tests/bot/resources/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/resources/__init__.py diff --git a/tests/bot/resources/test_resources.py b/tests/bot/resources/test_resources.py new file mode 100644 index 000000000..73937cfa6 --- /dev/null +++ b/tests/bot/resources/test_resources.py @@ -0,0 +1,17 @@ +import json +import unittest +from pathlib import Path + + +class ResourceValidationTests(unittest.TestCase): +    """Validates resources used by the bot.""" +    def test_stars_valid(self): +        """The resource `bot/resources/stars.json` should contain a list of strings.""" +        path = Path('bot', 'resources', 'stars.json') +        content = path.read_text() +        data = json.loads(content) + +        self.assertIsInstance(data, list) +        for name in data: +            with self.subTest(name=name): +                self.assertIsInstance(name, str) diff --git a/tests/bot/rules/__init__.py b/tests/bot/rules/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/rules/__init__.py diff --git a/tests/bot/rules/test_attachments.py b/tests/bot/rules/test_attachments.py new file mode 100644 index 000000000..4bb0acf7c --- /dev/null +++ b/tests/bot/rules/test_attachments.py @@ -0,0 +1,52 @@ +import asyncio +import unittest +from dataclasses import dataclass +from typing import Any, List + +from bot.rules import attachments + + +# Using `MagicMock` sadly doesn't work for this usecase +# since it's __eq__ compares the MagicMock's ID. We just +# want to compare the actual attributes we set. +@dataclass +class FakeMessage: +    author: str +    attachments: List[Any] + + +def msg(total_attachments: int) -> FakeMessage: +    return FakeMessage(author='lemon', attachments=list(range(total_attachments))) + + +class AttachmentRuleTests(unittest.TestCase): +    """Tests applying the `attachment` antispam rule.""" + +    def test_allows_messages_without_too_many_attachments(self): +        """Messages without too many attachments are allowed as-is.""" +        cases = ( +            (msg(0), msg(0), msg(0)), +            (msg(2), msg(2)), +            (msg(0),), +        ) + +        for last_message, *recent_messages in cases: +            with self.subTest(last_message=last_message, recent_messages=recent_messages): +                coro = attachments.apply(last_message, recent_messages, {'max': 5}) +                self.assertIsNone(asyncio.run(coro)) + +    def test_disallows_messages_with_too_many_attachments(self): +        """Messages with too many attachments trigger the rule.""" +        cases = ( +            ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), +            ((msg(6),), [msg(6)], 6), +            ((msg(1),) * 6, [msg(1)] * 6, 6), +        ) +        for messages, relevant_messages, total in cases: +            with self.subTest(messages=messages, relevant_messages=relevant_messages, total=total): +                last_message, *recent_messages = messages +                coro = attachments.apply(last_message, recent_messages, {'max': 5}) +                self.assertEqual( +                    asyncio.run(coro), +                    (f"sent {total} attachments in 5s", ('lemon',), relevant_messages) +                ) diff --git a/tests/bot/test_api.py b/tests/bot/test_api.py new file mode 100644 index 000000000..e0ede0eb1 --- /dev/null +++ b/tests/bot/test_api.py @@ -0,0 +1,134 @@ +import logging +import unittest +from unittest.mock import MagicMock, patch + +from bot import api +from tests.base import LoggingTestCase +from tests.helpers import async_test + + +class APIClientTests(unittest.TestCase): +    """Tests for the bot's API client.""" + +    @classmethod +    def setUpClass(cls): +        """Sets up the shared fixtures for the tests.""" +        cls.error_api_response = MagicMock() +        cls.error_api_response.status = 999 + +    def test_loop_is_not_running_by_default(self): +        """The event loop should not be running by default.""" +        self.assertFalse(api.loop_is_running()) + +    @async_test +    async def test_loop_is_running_in_async_context(self): +        """The event loop should be running in an async context.""" +        self.assertTrue(api.loop_is_running()) + +    def test_response_code_error_default_initialization(self): +        """Test the default initialization of `ResponseCodeError` without `text` or `json`""" +        error = api.ResponseCodeError(response=self.error_api_response) + +        self.assertIs(error.status, self.error_api_response.status) +        self.assertEqual(error.response_json, {}) +        self.assertEqual(error.response_text, "") +        self.assertIs(error.response, self.error_api_response) + +    def test_responde_code_error_string_representation_default_initialization(self): +        """Test the string representation of `ResponseCodeError` initialized without text or json.""" +        error = api.ResponseCodeError(response=self.error_api_response) +        self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: ") + +    def test_response_code_error_initialization_with_json(self): +        """Test the initialization of `ResponseCodeError` with json.""" +        json_data = {'hello': 'world'} +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_json=json_data, +        ) +        self.assertEqual(error.response_json, json_data) +        self.assertEqual(error.response_text, "") + +    def test_response_code_error_string_representation_with_nonempty_response_json(self): +        """Test the string representation of `ResponseCodeError` initialized with json.""" +        json_data = {'hello': 'world'} +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_json=json_data +        ) +        self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {json_data}") + +    def test_response_code_error_initialization_with_text(self): +        """Test the initialization of `ResponseCodeError` with text.""" +        text_data = 'Lemon will eat your soul' +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_text=text_data, +        ) +        self.assertEqual(error.response_text, text_data) +        self.assertEqual(error.response_json, {}) + +    def test_response_code_error_string_representation_with_nonempty_response_text(self): +        """Test the string representation of `ResponseCodeError` initialized with text.""" +        text_data = 'Lemon will eat your soul' +        error = api.ResponseCodeError( +            response=self.error_api_response, +            response_text=text_data +        ) +        self.assertEqual(str(error), f"Status: {self.error_api_response.status} Response: {text_data}") + + +class LoggingHandlerTests(LoggingTestCase): +    """Tests the bot's API Log Handler.""" + +    @classmethod +    def setUpClass(cls): +        cls.debug_log_record = logging.LogRecord( +            name='my.logger', level=logging.DEBUG, +            pathname='my/logger.py', lineno=666, +            msg="Lemon wins", args=(), +            exc_info=None +        ) + +        cls.trace_log_record = logging.LogRecord( +            name='my.logger', level=logging.TRACE, +            pathname='my/logger.py', lineno=666, +            msg="This will not be logged", args=(), +            exc_info=None +        ) + +    def setUp(self): +        self.log_handler = api.APILoggingHandler(None) + +    def test_emit_appends_to_queue_with_stopped_event_loop(self): +        """Test if `APILoggingHandler.emit` appends to queue when the event loop is not running.""" +        with patch("bot.api.APILoggingHandler.ship_off") as ship_off: +            # Patch `ship_off` to ease testing against the return value of this coroutine. +            ship_off.return_value = 42 +            self.log_handler.emit(self.debug_log_record) + +        self.assertListEqual(self.log_handler.queue, [42]) + +    def test_emit_ignores_less_than_debug(self): +        """`APILoggingHandler.emit` should not queue logs with a log level lower than DEBUG.""" +        self.log_handler.emit(self.trace_log_record) +        self.assertListEqual(self.log_handler.queue, []) + +    def test_schedule_queued_tasks_for_empty_queue(self): +        """`APILoggingHandler` should not schedule anything when the queue is empty.""" +        with self.assertNotLogs(level=logging.DEBUG): +            self.log_handler.schedule_queued_tasks() + +    def test_schedule_queued_tasks_for_nonempty_queue(self): +        """`APILoggingHandler` should schedule logs when the queue is not empty.""" +        with self.assertLogs(level=logging.DEBUG) as logs, patch('asyncio.create_task') as create_task: +            self.log_handler.queue = [555] +            self.log_handler.schedule_queued_tasks() +            self.assertListEqual(self.log_handler.queue, []) +            create_task.assert_called_once_with(555) + +            [record] = logs.records +            self.assertEqual(record.message, "Scheduled 1 pending logging tasks.") +            self.assertEqual(record.levelno, logging.DEBUG) +            self.assertEqual(record.name, 'bot.api') +            self.assertIn('via_handler', record.__dict__) diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py new file mode 100644 index 000000000..dae7c066c --- /dev/null +++ b/tests/bot/test_constants.py @@ -0,0 +1,26 @@ +import inspect +import unittest + +from bot import constants + + +class ConstantsTests(unittest.TestCase): +    """Tests for our constants.""" + +    def test_section_configuration_matches_type_specification(self): +        """The section annotations should match the actual types of the sections.""" + +        sections = ( +            cls +            for (name, cls) in inspect.getmembers(constants) +            if hasattr(cls, 'section') and isinstance(cls, type) +        ) +        for section in sections: +            for name, annotation in section.__annotations__.items(): +                with self.subTest(section=section, name=name, annotation=annotation): +                    value = getattr(section, name) + +                    if getattr(annotation, '_name', None) in ('Dict', 'List'): +                        self.skipTest("Cannot validate containers yet.") + +                    self.assertIsInstance(value, annotation) diff --git a/tests/bot/test_converters.py b/tests/bot/test_converters.py new file mode 100644 index 000000000..b2b78d9dd --- /dev/null +++ b/tests/bot/test_converters.py @@ -0,0 +1,273 @@ +import asyncio +import datetime +import unittest +from unittest.mock import MagicMock, patch + +from dateutil.relativedelta import relativedelta +from discord.ext.commands import BadArgument + +from bot.converters import ( +    Duration, +    ISODateTime, +    TagContentConverter, +    TagNameConverter, +    ValidPythonIdentifier, +) + + +class ConverterTests(unittest.TestCase): +    """Tests our custom argument converters.""" + +    @classmethod +    def setUpClass(cls): +        cls.context = MagicMock +        cls.context.author = 'bob' + +        cls.fixed_utc_now = datetime.datetime.fromisoformat('2019-01-01T00:00:00') + +    def test_tag_content_converter_for_valid(self): +        """TagContentConverter should return correct values for valid input.""" +        test_values = ( +            ('hello', 'hello'), +            ('  h ello  ', 'h ello'), +        ) + +        for content, expected_conversion in test_values: +            with self.subTest(content=content, expected_conversion=expected_conversion): +                conversion = asyncio.run(TagContentConverter.convert(self.context, content)) +                self.assertEqual(conversion, expected_conversion) + +    def test_tag_content_converter_for_invalid(self): +        """TagContentConverter should raise the proper exception for invalid input.""" +        test_values = ( +            ('', "Tag contents should not be empty, or filled with whitespace."), +            ('   ', "Tag contents should not be empty, or filled with whitespace."), +        ) + +        for value, exception_message in test_values: +            with self.subTest(tag_content=value, exception_message=exception_message): +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(TagContentConverter.convert(self.context, value)) + +    def test_tag_name_converter_for_valid(self): +        """TagNameConverter should return the correct values for valid tag names.""" +        test_values = ( +            ('tracebacks', 'tracebacks'), +            ('Tracebacks', 'tracebacks'), +            ('  Tracebacks  ', 'tracebacks'), +        ) + +        for name, expected_conversion in test_values: +            with self.subTest(name=name, expected_conversion=expected_conversion): +                conversion = asyncio.run(TagNameConverter.convert(self.context, name)) +                self.assertEqual(conversion, expected_conversion) + +    def test_tag_name_converter_for_invalid(self): +        """TagNameConverter should raise the correct exception for invalid tag names.""" +        test_values = ( +            ('👋', "Don't be ridiculous, you can't use that character!"), +            ('', "Tag names should not be empty, or filled with whitespace."), +            ('  ', "Tag names should not be empty, or filled with whitespace."), +            ('42', "Tag names can't be numbers."), +            ('x' * 128, "Are you insane? That's way too long!"), +        ) + +        for invalid_name, exception_message in test_values: +            with self.subTest(invalid_name=invalid_name, exception_message=exception_message): +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(TagNameConverter.convert(self.context, invalid_name)) + +    def test_valid_python_identifier_for_valid(self): +        """ValidPythonIdentifier returns valid identifiers unchanged.""" +        test_values = ('foo', 'lemon') + +        for name in test_values: +            with self.subTest(identifier=name): +                conversion = asyncio.run(ValidPythonIdentifier.convert(self.context, name)) +                self.assertEqual(name, conversion) + +    def test_valid_python_identifier_for_invalid(self): +        """ValidPythonIdentifier raises the proper exception for invalid identifiers.""" +        test_values = ('nested.stuff', '#####') + +        for name in test_values: +            with self.subTest(identifier=name): +                exception_message = f'`{name}` is not a valid Python identifier' +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(ValidPythonIdentifier.convert(self.context, name)) + +    def test_duration_converter_for_valid(self): +        """Duration returns the correct `datetime` for valid duration strings.""" +        test_values = ( +            # Simple duration strings +            ('1Y', {"years": 1}), +            ('1y', {"years": 1}), +            ('1year', {"years": 1}), +            ('1years', {"years": 1}), +            ('1m', {"months": 1}), +            ('1month', {"months": 1}), +            ('1months', {"months": 1}), +            ('1w', {"weeks": 1}), +            ('1W', {"weeks": 1}), +            ('1week', {"weeks": 1}), +            ('1weeks', {"weeks": 1}), +            ('1d', {"days": 1}), +            ('1D', {"days": 1}), +            ('1day', {"days": 1}), +            ('1days', {"days": 1}), +            ('1h', {"hours": 1}), +            ('1H', {"hours": 1}), +            ('1hour', {"hours": 1}), +            ('1hours', {"hours": 1}), +            ('1M', {"minutes": 1}), +            ('1minute', {"minutes": 1}), +            ('1minutes', {"minutes": 1}), +            ('1s', {"seconds": 1}), +            ('1S', {"seconds": 1}), +            ('1second', {"seconds": 1}), +            ('1seconds', {"seconds": 1}), + +            # Complex duration strings +            ( +                '1y1m1w1d1H1M1S', +                { +                    "years": 1, +                    "months": 1, +                    "weeks": 1, +                    "days": 1, +                    "hours": 1, +                    "minutes": 1, +                    "seconds": 1 +                } +            ), +            ('5y100S', {"years": 5, "seconds": 100}), +            ('2w28H', {"weeks": 2, "hours": 28}), + +            # Duration strings with spaces +            ('1 year 2 months', {"years": 1, "months": 2}), +            ('1d 2H', {"days": 1, "hours": 2}), +            ('1 week2 days', {"weeks": 1, "days": 2}), +        ) + +        converter = Duration() + +        for duration, duration_dict in test_values: +            expected_datetime = self.fixed_utc_now + relativedelta(**duration_dict) + +            with patch('bot.converters.datetime') as mock_datetime: +                mock_datetime.utcnow.return_value = self.fixed_utc_now + +                with self.subTest(duration=duration, duration_dict=duration_dict): +                    converted_datetime = asyncio.run(converter.convert(self.context, duration)) +                    self.assertEqual(converted_datetime, expected_datetime) + +    def test_duration_converter_for_invalid(self): +        """Duration raises the right exception for invalid duration strings.""" +        test_values = ( +            # Units in wrong order +            ('1d1w'), +            ('1s1y'), + +            # Duplicated units +            ('1 year 2 years'), +            ('1 M 10 minutes'), + +            # Unknown substrings +            ('1MVes'), +            ('1y3breads'), + +            # Missing amount +            ('ym'), + +            # Incorrect whitespace +            (" 1y"), +            ("1S "), +            ("1y  1m"), + +            # Garbage +            ('Guido van Rossum'), +            ('lemon lemon lemon lemon lemon lemon lemon'), +        ) + +        converter = Duration() + +        for invalid_duration in test_values: +            with self.subTest(invalid_duration=invalid_duration): +                exception_message = f'`{invalid_duration}` is not a valid duration string.' +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(converter.convert(self.context, invalid_duration)) + +    def test_isodatetime_converter_for_valid(self): +        """ISODateTime converter returns correct datetime for valid datetime string.""" +        test_values = ( +            # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` +            ('2019-09-02T02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM` +            ('2019-09-02T03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02T00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM` +            ('2019-09-02T03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02T00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH` +            ('2019-09-02 03:03:05+01', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02T01:03:05-01', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS` +            ('2019-09-02T02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)), +            ('2019-09-02 02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)), + +            # `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM` +            ('2019-11-12T09:15', datetime.datetime(2019, 11, 12, 9, 15)), +            ('2019-11-12 09:15', datetime.datetime(2019, 11, 12, 9, 15)), + +            # `YYYY-mm-dd` +            ('2019-04-01', datetime.datetime(2019, 4, 1)), + +            # `YYYY-mm` +            ('2019-02-01', datetime.datetime(2019, 2, 1)), + +            # `YYYY` +            ('2025', datetime.datetime(2025, 1, 1)), +        ) + +        converter = ISODateTime() + +        for datetime_string, expected_dt in test_values: +            with self.subTest(datetime_string=datetime_string, expected_dt=expected_dt): +                converted_dt = asyncio.run(converter.convert(self.context, datetime_string)) +                self.assertIsNone(converted_dt.tzinfo) +                self.assertEqual(converted_dt, expected_dt) + +    def test_isodatetime_converter_for_invalid(self): +        """ISODateTime converter raises the correct exception for invalid datetime strings.""" +        test_values = ( +            # Make sure it doesn't interfere with the Duration converter +            ('1Y'), +            ('1d'), +            ('1H'), + +            # Check if it fails when only providing the optional time part +            ('10:10:10'), +            ('10:00'), + +            # Invalid date format +            ('19-01-01'), + +            # Other non-valid strings +            ('fisk the tag master'), +        ) + +        converter = ISODateTime() +        for datetime_string in test_values: +            with self.subTest(datetime_string=datetime_string): +                exception_message = f"`{datetime_string}` is not a valid ISO-8601 datetime string" +                with self.assertRaises(BadArgument, msg=exception_message): +                    asyncio.run(converter.convert(self.context, datetime_string)) diff --git a/tests/test_pagination.py b/tests/bot/test_pagination.py index 11d6541ae..0a734b505 100644 --- a/tests/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -1,28 +1,35 @@  from unittest import TestCase -import pytest -  from bot import pagination  class LinePaginatorTests(TestCase): +    """Tests functionality of the `LinePaginator`.""" +      def setUp(self): +        """Create a paginator for the test method."""          self.paginator = pagination.LinePaginator(prefix='', suffix='', max_size=30)      def test_add_line_raises_on_too_long_lines(self): +        """`add_line` should raise a `RuntimeError` for too long lines."""          message = f"Line exceeds maximum page size {self.paginator.max_size - 2}" -        with pytest.raises(RuntimeError, match=message): +        with self.assertRaises(RuntimeError, msg=message):              self.paginator.add_line('x' * self.paginator.max_size)      def test_add_line_works_on_small_lines(self): +        """`add_line` should allow small lines to be added."""          self.paginator.add_line('x' * (self.paginator.max_size - 3))  class ImagePaginatorTests(TestCase): +    """Tests functionality of the `ImagePaginator`.""" +      def setUp(self): +        """Create a paginator for the test method."""          self.paginator = pagination.ImagePaginator()      def test_add_image_appends_image(self): +        """`add_image` appends the image to the image list."""          image = 'lemon'          self.paginator.add_image(image) diff --git a/tests/bot/utils/__init__.py b/tests/bot/utils/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tests/bot/utils/__init__.py diff --git a/tests/bot/utils/test_checks.py b/tests/bot/utils/test_checks.py new file mode 100644 index 000000000..22dc93073 --- /dev/null +++ b/tests/bot/utils/test_checks.py @@ -0,0 +1,43 @@ +import unittest + +from bot.utils import checks +from tests.helpers import MockContext, MockRole + + +class ChecksTests(unittest.TestCase): +    """Tests the check functions defined in `bot.checks`.""" + +    def setUp(self): +        self.ctx = MockContext() + +    def test_with_role_check_without_guild(self): +        """`with_role_check` returns `False` if `Context.guild` is None.""" +        self.ctx.guild = None +        self.assertFalse(checks.with_role_check(self.ctx)) + +    def test_with_role_check_without_required_roles(self): +        """`with_role_check` returns `False` if `Context.author` lacks the required role.""" +        self.ctx.author.roles = [] +        self.assertFalse(checks.with_role_check(self.ctx)) + +    def test_with_role_check_with_guild_and_required_role(self): +        """`with_role_check` returns `True` if `Context.author` has the required role.""" +        self.ctx.author.roles.append(MockRole(role_id=10)) +        self.assertTrue(checks.with_role_check(self.ctx, 10)) + +    def test_without_role_check_without_guild(self): +        """`without_role_check` should return `False` when `Context.guild` is None.""" +        self.ctx.guild = None +        self.assertFalse(checks.without_role_check(self.ctx)) + +    def test_without_role_check_returns_false_with_unwanted_role(self): +        """`without_role_check` returns `False` if `Context.author` has unwanted role.""" +        role_id = 42 +        self.ctx.author.roles.append(MockRole(role_id=role_id)) +        self.assertFalse(checks.without_role_check(self.ctx, role_id)) + +    def test_without_role_check_returns_true_without_unwanted_role(self): +        """`without_role_check` returns `True` if `Context.author` does not have unwanted role.""" +        role_id = 42 +        self.ctx.author.roles.append(MockRole(role_id=role_id)) +        self.assertTrue(checks.without_role_check(self.ctx, role_id + 10)) diff --git a/tests/cogs/sync/test_roles.py b/tests/cogs/sync/test_roles.py deleted file mode 100644 index c561ba447..000000000 --- a/tests/cogs/sync/test_roles.py +++ /dev/null @@ -1,103 +0,0 @@ -from bot.cogs.sync.syncers import Role, get_roles_for_sync - - -def test_get_roles_for_sync_empty_return_for_equal_roles(): -    api_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} -    guild_roles = {Role(id=41, name='name', colour=33, permissions=0x8, position=1)} - -    assert get_roles_for_sync(guild_roles, api_roles) == (set(), set(), set()) - - -def test_get_roles_for_sync_returns_roles_to_update_with_non_id_diff(): -    api_roles = {Role(id=41, name='old name', colour=35, permissions=0x8, position=1)} -    guild_roles = {Role(id=41, name='new name', colour=33, permissions=0x8, position=2)} - -    assert get_roles_for_sync(guild_roles, api_roles) == ( -        set(), -        guild_roles, -        set(), -    ) - - -def test_get_roles_only_returns_roles_that_require_update(): -    api_roles = { -        Role(id=41, name='old name', colour=33, permissions=0x8, position=1), -        Role(id=53, name='other role', colour=55, permissions=0, position=3) -    } -    guild_roles = { -        Role(id=41, name='new name', colour=35, permissions=0x8, position=2), -        Role(id=53, name='other role', colour=55, permissions=0, position=3) -    } - -    assert get_roles_for_sync(guild_roles, api_roles) == ( -        set(), -        {Role(id=41, name='new name', colour=35, permissions=0x8, position=2)}, -        set(), -    ) - - -def test_get_roles_returns_new_roles_in_first_tuple_element(): -    api_roles = { -        Role(id=41, name='name', colour=35, permissions=0x8, position=1), -    } -    guild_roles = { -        Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        Role(id=53, name='other role', colour=55, permissions=0, position=2) -    } - -    assert get_roles_for_sync(guild_roles, api_roles) == ( -        {Role(id=53, name='other role', colour=55, permissions=0, position=2)}, -        set(), -        set(), -    ) - - -def test_get_roles_returns_roles_to_update_and_new_roles(): -    api_roles = { -        Role(id=41, name='old name', colour=35, permissions=0x8, position=1), -    } -    guild_roles = { -        Role(id=41, name='new name', colour=40, permissions=0x16, position=2), -        Role(id=53, name='other role', colour=55, permissions=0, position=3) -    } - -    assert get_roles_for_sync(guild_roles, api_roles) == ( -        {Role(id=53, name='other role', colour=55, permissions=0, position=3)}, -        {Role(id=41, name='new name', colour=40, permissions=0x16, position=2)}, -        set(), -    ) - - -def test_get_roles_returns_roles_to_delete(): -    api_roles = { -        Role(id=41, name='name', colour=35, permissions=0x8, position=1), -        Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -    } -    guild_roles = { -        Role(id=41, name='name', colour=35, permissions=0x8, position=1), -    } - -    assert get_roles_for_sync(guild_roles, api_roles) == ( -        set(), -        set(), -        {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -    ) - - -def test_get_roles_returns_roles_to_delete_update_and_new_roles(): -    api_roles = { -        Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -        Role(id=61, name='to delete', colour=99, permissions=0x9, position=2), -        Role(id=71, name='to update', colour=99, permissions=0x9, position=3), -    } -    guild_roles = { -        Role(id=41, name='not changed', colour=35, permissions=0x8, position=1), -        Role(id=81, name='to create', colour=99, permissions=0x9, position=4), -        Role(id=71, name='updated', colour=101, permissions=0x5, position=3), -    } - -    assert get_roles_for_sync(guild_roles, api_roles) == ( -        {Role(id=81, name='to create', colour=99, permissions=0x9, position=4)}, -        {Role(id=71, name='updated', colour=101, permissions=0x5, position=3)}, -        {Role(id=61, name='to delete', colour=99, permissions=0x9, position=2)}, -    ) diff --git a/tests/cogs/sync/test_users.py b/tests/cogs/sync/test_users.py deleted file mode 100644 index a863ae35b..000000000 --- a/tests/cogs/sync/test_users.py +++ /dev/null @@ -1,69 +0,0 @@ -from bot.cogs.sync.syncers import User, get_users_for_sync - - -def fake_user(**kwargs): -    kwargs.setdefault('id', 43) -    kwargs.setdefault('name', 'bob the test man') -    kwargs.setdefault('discriminator', 1337) -    kwargs.setdefault('avatar_hash', None) -    kwargs.setdefault('roles', (666,)) -    kwargs.setdefault('in_guild', True) -    return User(**kwargs) - - -def test_get_users_for_sync_returns_nothing_for_empty_params(): -    assert get_users_for_sync({}, {}) == (set(), set()) - - -def test_get_users_for_sync_returns_nothing_for_equal_users(): -    api_users = {43: fake_user()} -    guild_users = {43: fake_user()} - -    assert get_users_for_sync(guild_users, api_users) == (set(), set()) - - -def test_get_users_for_sync_returns_users_to_update_on_non_id_field_diff(): -    api_users = {43: fake_user()} -    guild_users = {43: fake_user(name='new fancy name')} - -    assert get_users_for_sync(guild_users, api_users) == ( -        set(), -        {fake_user(name='new fancy name')} -    ) - - -def test_get_users_for_sync_returns_users_to_create_with_new_ids_on_guild(): -    api_users = {43: fake_user()} -    guild_users = {43: fake_user(), 63: fake_user(id=63)} - -    assert get_users_for_sync(guild_users, api_users) == ( -        {fake_user(id=63)}, -        set() -    ) - - -def test_get_users_for_sync_updates_in_guild_field_on_user_leave(): -    api_users = {43: fake_user(), 63: fake_user(id=63)} -    guild_users = {43: fake_user()} - -    assert get_users_for_sync(guild_users, api_users) == ( -        set(), -        {fake_user(id=63, in_guild=False)} -    ) - - -def test_get_users_for_sync_updates_and_creates_users_as_needed(): -    api_users = {43: fake_user()} -    guild_users = {63: fake_user(id=63)} - -    assert get_users_for_sync(guild_users, api_users) == ( -        {fake_user(id=63)}, -        {fake_user(in_guild=False)} -    ) - - -def test_get_users_for_sync_does_not_duplicate_update_users(): -    api_users = {43: fake_user(in_guild=False)} -    guild_users = {} - -    assert get_users_for_sync(guild_users, api_users) == (set(), set()) diff --git a/tests/cogs/test_antispam.py b/tests/cogs/test_antispam.py deleted file mode 100644 index 67900b275..000000000 --- a/tests/cogs/test_antispam.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from bot.cogs import antispam - - -def test_default_antispam_config_is_valid(): -    validation_errors = antispam.validate_config() -    assert not validation_errors - - -    ('config', 'expected'), -    ( -        ( -            {'invalid-rule': {}}, -            {'invalid-rule': "`invalid-rule` is not recognized as an antispam rule."} -        ), -        ( -            {'burst': {'interval': 10}}, -            {'burst': "Key `max` is required but not set for rule `burst`"} -        ), -        ( -            {'burst': {'max': 10}}, -            {'burst': "Key `interval` is required but not set for rule `burst`"} -        ) -    ) -) -def test_invalid_antispam_config_returns_validation_errors(config, expected): -    validation_errors = antispam.validate_config(config) -    assert validation_errors == expected diff --git a/tests/cogs/test_information.py b/tests/cogs/test_information.py deleted file mode 100644 index 184bd2595..000000000 --- a/tests/cogs/test_information.py +++ /dev/null @@ -1,211 +0,0 @@ -import asyncio -import logging -import textwrap -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -from discord import ( -    CategoryChannel, -    Colour, -    Permissions, -    Role, -    TextChannel, -    VoiceChannel, -) - -from bot.cogs import information -from bot.constants import Emojis -from bot.decorators import InChannelCheckFailure -from tests.helpers import AsyncMock - - -def cog(simple_bot): -    return information.Information(simple_bot) - - -def role(name: str, id_: int): -    r = MagicMock() -    r.name = name -    r.id = id_ -    r.mention = f'&{name}' -    return r - - -def member(status: str): -    m = MagicMock() -    m.status = status -    return m - - -def ctx(moderator_role, simple_ctx): -    simple_ctx.author.roles = [moderator_role] -    simple_ctx.guild.created_at = datetime(2001, 1, 1) -    simple_ctx.send = AsyncMock() -    return simple_ctx - - -def test_roles_info_command(cog, ctx): -    everyone_role = MagicMock() -    everyone_role.name = '@everyone'  # should be excluded in the output -    ctx.author.roles.append(everyone_role) -    ctx.guild.roles = ctx.author.roles - -    cog.roles_info.can_run = AsyncMock() -    cog.roles_info.can_run.return_value = True - -    coroutine = cog.roles_info.callback(cog, ctx) - -    assert asyncio.run(coroutine) is None  # no rval -    ctx.send.assert_called_once() -    _, kwargs = ctx.send.call_args -    embed = kwargs.pop('embed') -    assert embed.title == "Role information" -    assert embed.colour == Colour.blurple() -    assert embed.description == f"`{ctx.guild.roles[0].id}` - {ctx.guild.roles[0].mention}\n" -    assert embed.footer.text == "Total roles: 1" - - -def test_role_info_command(cog, ctx): -    dummy_role = MagicMock(spec=Role) -    dummy_role.name = "Dummy" -    dummy_role.colour = Colour.blurple() -    dummy_role.id = 112233445566778899 -    dummy_role.position = 10 -    dummy_role.permissions = Permissions(0) -    dummy_role.members = [ctx.author] - -    admin_role = MagicMock(spec=Role) -    admin_role.name = "Admin" -    admin_role.colour = Colour.red() -    admin_role.id = 998877665544332211 -    admin_role.position = 3 -    admin_role.permissions = Permissions(0) -    admin_role.members = [ctx.author] - -    ctx.guild.roles = [dummy_role, admin_role] - -    cog.role_info.can_run = AsyncMock() -    cog.role_info.can_run.return_value = True - -    coroutine = cog.role_info.callback(cog, ctx, dummy_role, admin_role) - -    assert asyncio.run(coroutine) is None - -    assert ctx.send.call_count == 2 - -    (_, dummy_kwargs), (_, admin_kwargs) = ctx.send.call_args_list - -    dummy_embed = dummy_kwargs["embed"] -    admin_embed = admin_kwargs["embed"] - -    assert dummy_embed.title == "Dummy info" -    assert dummy_embed.colour == Colour.blurple() - -    assert dummy_embed.fields[0].value == str(dummy_role.id) -    assert dummy_embed.fields[1].value == f"#{dummy_role.colour.value:0>6x}" -    assert dummy_embed.fields[2].value == "0.63 0.48 218" -    assert dummy_embed.fields[3].value == "1" -    assert dummy_embed.fields[4].value == "10" -    assert dummy_embed.fields[5].value == "0" - -    assert admin_embed.title == "Admin info" -    assert admin_embed.colour == Colour.red() - -# There is no argument passed in here that we can use to test, -# so the return value would change constantly. -@patch('bot.cogs.information.time_since') -def test_server_info_command(time_since_patch, cog, ctx, moderator_role): -    time_since_patch.return_value = '2 days ago' - -    ctx.guild.created_at = datetime(2001, 1, 1) -    ctx.guild.features = ('lemons', 'apples') -    ctx.guild.region = 'The Moon' -    ctx.guild.roles = [moderator_role] -    ctx.guild.channels = [ -        TextChannel( -            state={}, -            guild=ctx.guild, -            data={'id': 42, 'name': 'lemons-offering', 'position': 22, 'type': 'text'} -        ), -        CategoryChannel( -            state={}, -            guild=ctx.guild, -            data={'id': 5125, 'name': 'the-lemon-collection', 'position': 22, 'type': 'category'} -        ), -        VoiceChannel( -            state={}, -            guild=ctx.guild, -            data={'id': 15290, 'name': 'listen-to-lemon', 'position': 22, 'type': 'voice'} -        ) -    ] -    ctx.guild.members = [ -        member('online'), member('online'), -        member('idle'), -        member('dnd'), member('dnd'), member('dnd'), member('dnd'), -        member('offline'), member('offline'), member('offline') -    ] -    ctx.guild.member_count = 1_234 -    ctx.guild.icon_url = 'a-lemon.png' - -    coroutine = cog.server_info.callback(cog, ctx) -    assert asyncio.run(coroutine) is None  # no rval - -    time_since_patch.assert_called_once_with(ctx.guild.created_at, precision='days') -    _, kwargs = ctx.send.call_args -    embed = kwargs.pop('embed') -    assert embed.colour == Colour.blurple() -    assert embed.description == textwrap.dedent(f""" -        **Server information** -        Created: {time_since_patch.return_value} -        Voice region: {ctx.guild.region} -        Features: {', '.join(ctx.guild.features)} - -        **Counts** -        Members: {ctx.guild.member_count:,} -        Roles: {len(ctx.guild.roles)} -        Text: 1 -        Voice: 1 -        Channel categories: 1 - -        **Members** -        {Emojis.status_online} 2 -        {Emojis.status_idle} 1 -        {Emojis.status_dnd} 4 -        {Emojis.status_offline} 3 -        """) -    assert embed.thumbnail.url == 'a-lemon.png' - - -def test_user_info_on_other_users_from_non_moderator(ctx, cog): -    ctx.author = MagicMock() -    ctx.author.__eq__.return_value = False -    ctx.author.roles = [] -    coroutine = cog.user_info.callback(cog, ctx, user='scragly')  # skip checks, pass args - -    assert asyncio.run(coroutine) is None  # no rval -    ctx.send.assert_called_once_with( -        "You may not use this command on users other than yourself." -    ) - - -def test_user_info_in_wrong_channel_from_non_moderator(ctx, cog): -    ctx.author = MagicMock() -    ctx.author.__eq__.return_value = False -    ctx.author.roles = [] - -    coroutine = cog.user_info.callback(cog, ctx) -    message = 'Sorry, but you may only use this command within <#267659945086812160>.' -    with pytest.raises(InChannelCheckFailure, match=message): -        assert asyncio.run(coroutine) is None  # no rval - - -def test_setup(simple_bot, caplog): -    information.setup(simple_bot) -    simple_bot.add_cog.assert_called_once() -    [record] = caplog.records - -    assert record.message == "Cog loaded: Information" -    assert record.levelno == logging.INFO diff --git a/tests/cogs/test_security.py b/tests/cogs/test_security.py deleted file mode 100644 index 1efb460fe..000000000 --- a/tests/cogs/test_security.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging -from unittest.mock import MagicMock - -import pytest -from discord.ext.commands import NoPrivateMessage - -from bot.cogs import security - - -def cog(): -    bot = MagicMock() -    return security.Security(bot) - - -def context(): -    return MagicMock() - - -def test_check_additions(cog): -    cog.bot.check.assert_any_call(cog.check_on_guild) -    cog.bot.check.assert_any_call(cog.check_not_bot) - - -def test_check_not_bot_for_humans(cog, context): -    context.author.bot = False -    assert cog.check_not_bot(context) - - -def test_check_not_bot_for_robots(cog, context): -    context.author.bot = True -    assert not cog.check_not_bot(context) - - -def test_check_on_guild_outside_of_guild(cog, context): -    context.guild = None - -    with pytest.raises(NoPrivateMessage, match="This command cannot be used in private messages."): -        cog.check_on_guild(context) - - -def test_check_on_guild_on_guild(cog, context): -    context.guild = "lemon's lemonade stand" -    assert cog.check_on_guild(context) - - -def test_security_cog_load(caplog): -    bot = MagicMock() -    security.setup(bot) -    bot.add_cog.assert_called_once() -    [record] = caplog.records -    assert record.message == "Cog loaded: Security" -    assert record.levelno == logging.INFO diff --git a/tests/cogs/test_token_remover.py b/tests/cogs/test_token_remover.py deleted file mode 100644 index 9d46b3a05..000000000 --- a/tests/cogs/test_token_remover.py +++ /dev/null @@ -1,133 +0,0 @@ -import asyncio -from unittest.mock import MagicMock - -import pytest -from discord import Colour - -from bot.cogs.token_remover import ( -    DELETION_MESSAGE_TEMPLATE, -    TokenRemover, -    setup as setup_cog, -) -from bot.constants import Channels, Colours, Event, Icons -from tests.helpers import AsyncMock - - -def token_remover(): -    bot = MagicMock() -    bot.get_cog.return_value = MagicMock() -    bot.get_cog.return_value.send_log_message = AsyncMock() -    return TokenRemover(bot=bot) - - -def message(): -    message = MagicMock() -    message.author.__str__.return_value = 'lemon' -    message.author.bot = False -    message.author.avatar_url_as.return_value = 'picture-lemon.png' -    message.author.id = 42 -    message.author.mention = '@lemon' -    message.channel.send = AsyncMock() -    message.channel.mention = '#lemonade-stand' -    message.content = '' -    message.delete = AsyncMock() -    message.id = 555 -    return message - - -    ('content', 'expected'), -    ( -        ('MTIz', True),  # 123 -        ('YWJj', False),  # abc -    ) -) -def test_is_valid_user_id(content: str, expected: bool): -    assert TokenRemover.is_valid_user_id(content) is expected - - -    ('content', 'expected'), -    ( -        ('DN9r_A', True),  # stolen from dapi, thanks to the author of the 'token' tag! -        ('MTIz', False),  # 123 -    ) -) -def test_is_valid_timestamp(content: str, expected: bool): -    assert TokenRemover.is_valid_timestamp(content) is expected - - -def test_mod_log_property(token_remover): -    token_remover.bot.get_cog.return_value = 'lemon' -    assert token_remover.mod_log == 'lemon' -    token_remover.bot.get_cog.assert_called_once_with('ModLog') - - -def test_ignores_bot_messages(token_remover, message): -    message.author.bot = True -    coroutine = token_remover.on_message(message) -    assert asyncio.run(coroutine) is None - - [email protected]('content', ('', 'lemon wins')) -def test_ignores_messages_without_tokens(token_remover, message, content): -    message.content = content -    coroutine = token_remover.on_message(message) -    assert asyncio.run(coroutine) is None - - [email protected]('content', ('foo.bar.baz', 'x.y.')) -def test_ignores_invalid_tokens(token_remover, message, content): -    message.content = content -    coroutine = token_remover.on_message(message) -    assert asyncio.run(coroutine) is None - - -    'content, censored_token', -    ( -        ('MTIz.DN9R_A.xyz', 'MTIz.DN9R_A.xxx'), -    ) -) -def test_censors_valid_tokens( -    token_remover, message, content, censored_token, caplog -): -    message.content = content -    coroutine = token_remover.on_message(message) -    assert asyncio.run(coroutine) is None  # still no rval - -    # asyncio logs some stuff about its reactor, discard it -    [_, record] = caplog.records -    assert record.message == ( -        "Censored a seemingly valid token sent by lemon (`42`) in #lemonade-stand, " -        f"token was `{censored_token}`" -    ) - -    message.delete.assert_called_once_with() -    message.channel.send.assert_called_once_with( -        DELETION_MESSAGE_TEMPLATE.format(mention='@lemon') -    ) -    token_remover.bot.get_cog.assert_called_with('ModLog') -    message.author.avatar_url_as.assert_called_once_with(static_format='png') - -    mod_log = token_remover.bot.get_cog.return_value -    mod_log.ignore.assert_called_once_with(Event.message_delete, message.id) -    mod_log.send_log_message.assert_called_once_with( -        icon_url=Icons.token_removed, -        colour=Colour(Colours.soft_red), -        title="Token removed!", -        text=record.message, -        thumbnail='picture-lemon.png', -        channel_id=Channels.mod_alerts -    ) - - -def test_setup(caplog): -    bot = MagicMock() -    setup_cog(bot) -    [record] = caplog.records - -    bot.add_cog.assert_called_once() -    assert record.message == "Cog loaded: TokenRemover" diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index d3de4484d..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,32 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from bot.constants import Roles -from tests.helpers import AsyncMock - - -def moderator_role(): -    mock = MagicMock() -    mock.id = Roles.moderator -    mock.name = 'Moderator' -    mock.mention = f'&{mock.name}' -    return mock - - -def simple_bot(): -    mock = MagicMock() -    mock._before_invoke = AsyncMock() -    mock._after_invoke = AsyncMock() -    mock.can_run = AsyncMock() -    mock.can_run.return_value = True -    return mock - - -def simple_ctx(simple_bot): -    mock = MagicMock() -    mock.bot = simple_bot -    return mock diff --git a/tests/helpers.py b/tests/helpers.py index 25059fa3a..892d42e6c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,27 +1,18 @@ +from __future__ import annotations +  import asyncio  import functools -from unittest.mock import MagicMock - - -__all__ = ('AsyncMock', 'async_test') - +import unittest.mock +from typing import Iterable, Optional -# TODO: Remove me on 3.8 -# Allows you to mock a coroutine. Since the default `__call__` of `MagicMock` -# is not a coroutine, trying to mock a coroutine with it will result in errors -# as the default `__call__` is not awaitable. Use this class for monkeypatching -# coroutines instead. -class AsyncMock(MagicMock): -    async def __call__(self, *args, **kwargs): -        return super(AsyncMock, self).__call__(*args, **kwargs) +import discord +from discord.ext.commands import Bot, Context  def async_test(wrapped):      """      Run a test case via asyncio. -      Example: -          >>> @async_test          ... async def lemon_wins():          ...     assert True @@ -31,3 +22,407 @@ def async_test(wrapped):      def wrapper(*args, **kwargs):          return asyncio.run(wrapped(*args, **kwargs))      return wrapper + + +# TODO: Remove me in Python 3.8 +class AsyncMock(unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock async callables. + +    Python 3.8 will introduce an AsyncMock class in the standard library that will have some more +    features; this stand-in only overwrites the `__call__` method to an async version. +    """ + +    async def __call__(self, *args, **kwargs): +        return super(AsyncMock, self).__call__(*args, **kwargs) + + +class HashableMixin(discord.mixins.EqualityComparable): +    """ +    Mixin that provides similar hashing and equality functionality as discord.py's `Hashable` mixin. + +    Note: discord.py`s `Hashable` mixin bit-shifts `self.id` (`>> 22`); to prevent hash-collisions +    for the relative small `id` integers we generally use in tests, this bit-shift is omitted. +    """ + +    def __hash__(self): +        return self.id + + +class ColourMixin: +    """A mixin for Mocks that provides the aliasing of color->colour like discord.py does.""" + +    @property +    def color(self) -> discord.Colour: +        return self.colour + +    @color.setter +    def color(self, color: discord.Colour) -> None: +        self.colour = color + + +class AttributeMock: +    """Ensures attributes of our mock types will be instantiated with the correct mock type.""" + +    def __new__(cls, *args, **kwargs): +        """Stops the regular parent class from propagating to newly mocked attributes.""" +        if 'parent' in kwargs: +            return cls.attribute_mocktype(*args, **kwargs) + +        return super().__new__(cls) + + +# Create a guild instance to get a realistic Mock of `discord.Guild` +guild_data = { +    'id': 1, +    'name': 'guild', +    'region': 'Europe', +    'verification_level': 2, +    'default_notications': 1, +    'afk_timeout': 100, +    'icon': "icon.png", +    'banner': 'banner.png', +    'mfa_level': 1, +    'splash': 'splash.png', +    'system_channel_id': 464033278631084042, +    'description': 'mocking is fun', +    'max_presences': 10_000, +    'max_members': 100_000, +    'preferred_locale': 'UTC', +    'owner_id': 1, +    'afk_channel_id': 464033278631084042, +} +guild_instance = discord.Guild(data=guild_data, state=unittest.mock.MagicMock()) + + +class MockGuild(AttributeMock, unittest.mock.Mock, HashableMixin): +    """ +    A `Mock` subclass to mock `discord.Guild` objects. + +    A MockGuild instance will follow the specifications of a `discord.Guild` instance. This means +    that if the code you're testing tries to access an attribute or method that normally does not +    exist for a `discord.Guild` object this will raise an `AttributeError`. This is to make sure our +    tests fail if the code we're testing uses a `discord.Guild` object in the wrong way. + +    One restriction of that is that if the code tries to access an attribute that normally does not +    exist for `discord.Guild` instance but was added dynamically, this will raise an exception with +    the mocked object. To get around that, you can set the non-standard attribute explicitly for the +    instance of `MockGuild`: + +    >>> guild = MockGuild() +    >>> guild.attribute_that_normally_does_not_exist = unittest.mock.MagicMock() + +    In addition to attribute simulation, mocked guild object will pass an `isinstance` check against +    `discord.Guild`: + +    >>> guild = MockGuild() +    >>> isinstance(guild, discord.Guild) +    True + +    For more info, see the `Mocking` section in `tests/README.md`. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__( +        self, +        guild_id: int = 1, +        roles: Optional[Iterable[MockRole]] = None, +        members: Optional[Iterable[MockMember]] = None, +        **kwargs, +    ) -> None: +        super().__init__(spec=guild_instance, **kwargs) + +        self.id = guild_id + +        self.roles = [MockRole("@everyone", 1)] +        if roles: +            self.roles.extend(roles) + +        self.members = [] +        if members: +            self.members.extend(members) + +        # `discord.Guild` coroutines +        self.create_category_channel = AsyncMock() +        self.ban = AsyncMock() +        self.bans = AsyncMock() +        self.create_category = AsyncMock() +        self.create_custom_emoji = AsyncMock() +        self.create_role = AsyncMock() +        self.create_text_channel = AsyncMock() +        self.create_voice_channel = AsyncMock() +        self.delete = AsyncMock() +        self.edit = AsyncMock() +        self.estimate_pruned_members = AsyncMock() +        self.fetch_ban = AsyncMock() +        self.fetch_channels = AsyncMock() +        self.fetch_emoji = AsyncMock() +        self.fetch_emojis = AsyncMock() +        self.fetch_member = AsyncMock() +        self.invites = AsyncMock() +        self.kick = AsyncMock() +        self.leave = AsyncMock() +        self.prune_members = AsyncMock() +        self.unban = AsyncMock() +        self.vanity_invite = AsyncMock() +        self.webhooks = AsyncMock() +        self.widget = AsyncMock() + + +# Create a Role instance to get a realistic Mock of `discord.Role` +role_data = {'name': 'role', 'id': 1} +role_instance = discord.Role(guild=guild_instance, state=unittest.mock.MagicMock(), data=role_data) + + +class MockRole(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin): +    """ +    A Mock subclass to mock `discord.Role` objects. + +    Instances of this class will follow the specifications of `discord.Role` instances. For more +    information, see the `MockGuild` docstring. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__(self, name: str = "role", role_id: int = 1, position: int = 1, **kwargs) -> None: +        super().__init__(spec=role_instance, **kwargs) + +        self.name = name +        self.id = role_id +        self.position = position +        self.mention = f'&{self.name}' + +        # 'discord.Role' coroutines +        self.delete = AsyncMock() +        self.edit = AsyncMock() + +    def __lt__(self, other): +        """Simplified position-based comparisons similar to those of `discord.Role`.""" +        return self.position < other.position + + +# Create a Member instance to get a realistic Mock of `discord.Member` +member_data = {'user': 'lemon', 'roles': [1]} +state_mock = unittest.mock.MagicMock() +member_instance = discord.Member(data=member_data, guild=guild_instance, state=state_mock) + + +class MockMember(AttributeMock, unittest.mock.Mock, ColourMixin, HashableMixin): +    """ +    A Mock subclass to mock Member objects. + +    Instances of this class will follow the specifications of `discord.Member` instances. For more +    information, see the `MockGuild` docstring. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__( +        self, +        name: str = "member", +        user_id: int = 1, +        roles: Optional[Iterable[MockRole]] = None, +        **kwargs, +    ) -> None: +        super().__init__(spec=member_instance, **kwargs) + +        self.name = name +        self.id = user_id + +        self.roles = [MockRole("@everyone", 1)] +        if roles: +            self.roles.extend(roles) + +        self.mention = f"@{self.name}" + +        # `discord.Member` coroutines +        self.add_roles = AsyncMock() +        self.ban = AsyncMock() +        self.edit = AsyncMock() +        self.fetch_message = AsyncMock() +        self.kick = AsyncMock() +        self.move_to = AsyncMock() +        self.pins = AsyncMock() +        self.remove_roles = AsyncMock() +        self.send = AsyncMock() +        self.trigger_typing = AsyncMock() +        self.unban = AsyncMock() + + +# Create a Bot instance to get a realistic MagicMock of `discord.ext.commands.Bot` +bot_instance = Bot(command_prefix=unittest.mock.MagicMock()) + + +class MockBot(AttributeMock, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Bot objects. + +    Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. +    For more information, see the `MockGuild` docstring. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=bot_instance, **kwargs) + +        # `discord.ext.commands.Bot` coroutines +        self._before_invoke = AsyncMock() +        self._after_invoke = AsyncMock() +        self.application_info = AsyncMock() +        self.change_presence = AsyncMock() +        self.connect = AsyncMock() +        self.close = AsyncMock() +        self.create_guild = AsyncMock() +        self.delete_invite = AsyncMock() +        self.fetch_channel = AsyncMock() +        self.fetch_guild = AsyncMock() +        self.fetch_guilds = AsyncMock() +        self.fetch_invite = AsyncMock() +        self.fetch_user = AsyncMock() +        self.fetch_user_profile = AsyncMock() +        self.fetch_webhook = AsyncMock() +        self.fetch_widget = AsyncMock() +        self.get_context = AsyncMock() +        self.get_prefix = AsyncMock() +        self.invoke = AsyncMock() +        self.is_owner = AsyncMock() +        self.login = AsyncMock() +        self.logout = AsyncMock() +        self.on_command_error = AsyncMock() +        self.on_error = AsyncMock() +        self.process_commands = AsyncMock() +        self.request_offline_members = AsyncMock() +        self.start = AsyncMock() +        self.wait_until_ready = AsyncMock() +        self.wait_for = AsyncMock() + + +# Create a Context instance to get a realistic MagicMock of `discord.ext.commands.Context` +context_instance = Context(message=unittest.mock.MagicMock(), prefix=unittest.mock.MagicMock()) + + +class MockContext(AttributeMock, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Context objects. + +    Instances of this class will follow the specifications of `discord.ext.commands.Context` +    instances. For more information, see the `MockGuild` docstring. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=context_instance, **kwargs) +        self.bot = MockBot() +        self.guild = MockGuild() +        self.author = MockMember() +        self.command = unittest.mock.MagicMock() + +        # `discord.ext.commands.Context` coroutines +        self.fetch_message = AsyncMock() +        self.invoke = AsyncMock() +        self.pins = AsyncMock() +        self.reinvoke = AsyncMock() +        self.send = AsyncMock() +        self.send_help = AsyncMock() +        self.trigger_typing = AsyncMock() + + +# Create a TextChannel instance to get a realistic MagicMock of `discord.TextChannel` +channel_data = { +    'id': 1, +    'type': 'TextChannel', +    'name': 'channel', +    'parent_id': 1234567890, +    'topic': 'topic', +    'position': 1, +    'nsfw': False, +    'last_message_id': 1, +} +state = unittest.mock.MagicMock() +guild = unittest.mock.MagicMock() +channel_instance = discord.TextChannel(state=state, guild=guild, data=channel_data) + + +class MockTextChannel(AttributeMock, unittest.mock.Mock, HashableMixin): +    """ +    A MagicMock subclass to mock TextChannel objects. + +    Instances of this class will follow the specifications of `discord.TextChannel` instances. For +    more information, see the `MockGuild` docstring. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__(self, name: str = 'channel', channel_id: int = 1, **kwargs) -> None: +        super().__init__(spec=channel_instance, **kwargs) +        self.id = channel_id +        self.name = name +        self.guild = MockGuild() +        self.mention = f"#{self.name}" + +        # `discord.TextChannel` coroutines +        self.clone = AsyncMock() +        self.create_invite = AsyncMock() +        self.create_webhook = AsyncMock() +        self.delete = AsyncMock() +        self.delete_messages = AsyncMock() +        self.edit = AsyncMock() +        self.fetch_message = AsyncMock() +        self.invites = AsyncMock() +        self.pins = AsyncMock() +        self.purge = AsyncMock() +        self.send = AsyncMock() +        self.set_permissions = AsyncMock() +        self.trigger_typing = AsyncMock() +        self.webhooks = AsyncMock() + + +# Create a Message instance to get a realistic MagicMock of `discord.Message` +message_data = { +    'id': 1, +    'webhook_id': 431341013479718912, +    'attachments': [], +    'embeds': [], +    'application': 'Python Discord', +    'activity': 'mocking', +    'channel': unittest.mock.MagicMock(), +    'edited_timestamp': '2019-10-14T15:33:48+00:00', +    'type': 'message', +    'pinned': False, +    'mention_everyone': False, +    'tts': None, +    'content': 'content', +    'nonce': None, +} +state = unittest.mock.MagicMock() +channel = unittest.mock.MagicMock() +message_instance = discord.Message(state=state, channel=channel, data=message_data) + + +class MockMessage(AttributeMock, unittest.mock.MagicMock): +    """ +    A MagicMock subclass to mock Message objects. + +    Instances of this class will follow the specifications of `discord.Message` instances. For more +    information, see the `MockGuild` docstring. +    """ + +    attribute_mocktype = unittest.mock.MagicMock + +    def __init__(self, **kwargs) -> None: +        super().__init__(spec=message_instance, **kwargs) +        self.author = MockMember() +        self.channel = MockTextChannel() + +        # `discord.Message` coroutines +        self.ack = AsyncMock() +        self.add_reaction = AsyncMock() +        self.clear_reactions = AsyncMock() +        self.delete = AsyncMock() +        self.edit = AsyncMock() +        self.pin = AsyncMock() +        self.remove_reaction = AsyncMock() +        self.unpin = AsyncMock() diff --git a/tests/rules/test_attachments.py b/tests/rules/test_attachments.py deleted file mode 100644 index 6f025b3cb..000000000 --- a/tests/rules/test_attachments.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -from dataclasses import dataclass -from typing import Any, List - -import pytest - -from bot.rules import attachments - - -# Using `MagicMock` sadly doesn't work for this usecase -# since it's __eq__ compares the MagicMock's ID. We just -# want to compare the actual attributes we set. -@dataclass -class FakeMessage: -    author: str -    attachments: List[Any] - - -def msg(total_attachments: int): -    return FakeMessage(author='lemon', attachments=list(range(total_attachments))) - - -    'messages', -    ( -        (msg(0), msg(0), msg(0)), -        (msg(2), msg(2)), -        (msg(0),), -    ) -) -def test_allows_messages_without_too_many_attachments(messages): -    last_message, *recent_messages = messages -    coro = attachments.apply(last_message, recent_messages, {'max': 5}) -    assert asyncio.run(coro) is None - - -    ('messages', 'relevant_messages', 'total'), -    ( -        ((msg(4), msg(0), msg(6)), [msg(4), msg(6)], 10), -        ((msg(6),), [msg(6)], 6), -        ((msg(1),) * 6, [msg(1)] * 6, 6), -    ) -) -def test_disallows_messages_with_too_many_attachments(messages, relevant_messages, total): -    last_message, *recent_messages = messages -    coro = attachments.apply(last_message, recent_messages, {'max': 5}) -    assert asyncio.run(coro) == ( -        f"sent {total} attachments in 5s", -        ('lemon',), -        relevant_messages -    ) diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index ce69ef187..000000000 --- a/tests/test_api.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging -from unittest.mock import MagicMock, patch - -import pytest - -from bot import api -from tests.helpers import async_test - - -def test_loop_is_not_running_by_default(): -    assert not api.loop_is_running() - - -@async_test -async def test_loop_is_running_in_async_test(): -    assert api.loop_is_running() - - -def error_api_response(): -    response = MagicMock() -    response.status = 999 -    return response - - -def api_log_handler(): -    return api.APILoggingHandler(None) - - -def debug_log_record(): -    return logging.LogRecord( -        name='my.logger', level=logging.DEBUG, -        pathname='my/logger.py', lineno=666, -        msg="Lemon wins", args=(), -        exc_info=None -    ) - - -def test_response_code_error_default_initialization(error_api_response): -    error = api.ResponseCodeError(response=error_api_response) -    assert error.status is error_api_response.status -    assert not error.response_json -    assert not error.response_text -    assert error.response is error_api_response - - -def test_response_code_error_default_representation(error_api_response): -    error = api.ResponseCodeError(response=error_api_response) -    assert str(error) == f"Status: {error_api_response.status} Response: " - - -def test_response_code_error_representation_with_nonempty_response_json(error_api_response): -    error = api.ResponseCodeError( -        response=error_api_response, -        response_json={'hello': 'world'} -    ) -    assert str(error) == f"Status: {error_api_response.status} Response: {{'hello': 'world'}}" - - -def test_response_code_error_representation_with_nonempty_response_text(error_api_response): -    error = api.ResponseCodeError( -        response=error_api_response, -        response_text='Lemon will eat your soul' -    ) -    assert str(error) == f"Status: {error_api_response.status} Response: Lemon will eat your soul" - - -@patch('bot.api.APILoggingHandler.ship_off') -def test_emit_appends_to_queue_with_stopped_event_loop( -    ship_off_patch, api_log_handler, debug_log_record -): -    # This is a coroutine so returns something we should await, -    # but asyncio complains about that. To ease testing, we patch -    # `ship_off` to just return a regular value instead. -    ship_off_patch.return_value = 42 -    api_log_handler.emit(debug_log_record) - -    assert api_log_handler.queue == [42] - - -def test_emit_ignores_less_than_debug(debug_log_record, api_log_handler): -    debug_log_record.levelno = logging.DEBUG - 5 -    api_log_handler.emit(debug_log_record) -    assert not api_log_handler.queue - - -def test_schedule_queued_tasks_for_empty_queue(api_log_handler, caplog): -    api_log_handler.schedule_queued_tasks() -    # Logs when tasks are scheduled -    assert not caplog.records - - -@patch('asyncio.create_task') -def test_schedule_queued_tasks_for_nonempty_queue(create_task_patch, api_log_handler, caplog): -    api_log_handler.queue = [555] -    api_log_handler.schedule_queued_tasks() -    assert not api_log_handler.queue -    create_task_patch.assert_called_once_with(555) - -    [record] = caplog.records -    assert record.message == "Scheduled 1 pending logging tasks." -    assert record.levelno == logging.DEBUG -    assert record.name == 'bot.api' -    assert record.__dict__['via_handler'] diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 000000000..a16e2af8f --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,91 @@ +import logging +import unittest +import unittest.mock + + +from tests.base import LoggingTestCase, _CaptureLogHandler + + +class LoggingTestCaseTests(unittest.TestCase): +    """Tests for the LoggingTestCase.""" + +    @classmethod +    def setUpClass(cls): +        cls.log = logging.getLogger(__name__) + +    def test_assert_not_logs_does_not_raise_with_no_logs(self): +        """Test if LoggingTestCase.assertNotLogs does not raise when no logs were emitted.""" +        try: +            with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): +                pass +        except AssertionError: +            self.fail("`self.assertNotLogs` raised an AssertionError when it should not!") + +    @unittest.mock.patch("tests.base.LoggingTestCase.assertNotLogs") +    def test_the_test_function_assert_not_logs_does_not_raise_with_no_logs(self, assertNotLogs): +        """Test if test_assert_not_logs_does_not_raise_with_no_logs captures exception correctly.""" +        assertNotLogs.return_value = iter([None]) +        assertNotLogs.side_effect = AssertionError + +        message = "`self.assertNotLogs` raised an AssertionError when it should not!" +        with self.assertRaises(AssertionError, msg=message): +            self.test_assert_not_logs_does_not_raise_with_no_logs() + +    def test_assert_not_logs_raises_correct_assertion_error_when_logs_are_emitted(self): +        """Test if LoggingTestCase.assertNotLogs raises AssertionError when logs were emitted.""" +        msg_regex = ( +            r"1 logs of DEBUG or higher were triggered on root:\n" +            r'<LogRecord: tests\.test_base, [\d]+, .+/tests/test_base\.py, [\d]+, "Log!">' +        ) +        with self.assertRaisesRegex(AssertionError, msg_regex): +            with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): +                self.log.debug("Log!") + +    def test_assert_not_logs_reraises_unexpected_exception_in_managed_context(self): +        """Test if LoggingTestCase.assertNotLogs reraises an unexpected exception.""" +        with self.assertRaises(ValueError, msg="test exception"): +            with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): +                raise ValueError("test exception") + +    def test_assert_not_logs_restores_old_logging_settings(self): +        """Test if LoggingTestCase.assertNotLogs reraises an unexpected exception.""" +        old_handlers = self.log.handlers[:] +        old_level = self.log.level +        old_propagate = self.log.propagate + +        with LoggingTestCase.assertNotLogs(self, level=logging.DEBUG): +            pass + +        self.assertEqual(self.log.handlers, old_handlers) +        self.assertEqual(self.log.level, old_level) +        self.assertEqual(self.log.propagate, old_propagate) + +    def test_logging_test_case_works_with_logger_instance(self): +        """Test if the LoggingTestCase captures logging for provided logger.""" +        log = logging.getLogger("new_logger") +        with self.assertRaises(AssertionError): +            with LoggingTestCase.assertNotLogs(self, logger=log): +                log.info("Hello, this should raise an AssertionError") + +    def test_logging_test_case_respects_alternative_logger(self): +        """Test if LoggingTestCase only checks the provided logger.""" +        log_one = logging.getLogger("log one") +        log_two = logging.getLogger("log two") +        with LoggingTestCase.assertNotLogs(self, logger=log_one): +            log_two.info("Hello, this should not raise an AssertionError") + +    def test_logging_test_case_respects_logging_level(self): +        """Test if LoggingTestCase does not raise for a logging level lower than provided.""" +        with LoggingTestCase.assertNotLogs(self, level=logging.CRITICAL): +            self.log.info("Hello, this should raise an AssertionError") + +    def test_capture_log_handler_default_initialization(self): +        """Test if the _CaptureLogHandler is initialized properly.""" +        handler = _CaptureLogHandler() +        self.assertFalse(handler.records) + +    def test_capture_log_handler_saves_record_on_emit(self): +        """Test if the _CaptureLogHandler saves the log record when it's emitted.""" +        handler = _CaptureLogHandler() +        handler.emit("Log message") +        self.assertIn("Log message", handler.records) diff --git a/tests/test_constants.py b/tests/test_constants.py deleted file mode 100644 index e4a29d994..000000000 --- a/tests/test_constants.py +++ /dev/null @@ -1,23 +0,0 @@ -import inspect - -import pytest - -from bot import constants - - -    'section', -    ( -        cls -        for (name, cls) in inspect.getmembers(constants) -        if hasattr(cls, 'section') and isinstance(cls, type) -    ) -) -def test_section_configuration_matches_typespec(section): -    for (name, annotation) in section.__annotations__.items(): -        value = getattr(section, name) - -        if getattr(annotation, '_name', None) in ('Dict', 'List'): -            pytest.skip("Cannot validate containers yet") - -        assert isinstance(value, annotation) diff --git a/tests/test_converters.py b/tests/test_converters.py deleted file mode 100644 index f69995ec6..000000000 --- a/tests/test_converters.py +++ /dev/null @@ -1,264 +0,0 @@ -import asyncio -import datetime -from unittest.mock import MagicMock, patch - -import pytest -from dateutil.relativedelta import relativedelta -from discord.ext.commands import BadArgument - -from bot.converters import ( -    Duration, -    ISODateTime, -    TagContentConverter, -    TagNameConverter, -    ValidPythonIdentifier, -) - - -    ('value', 'expected'), -    ( -        ('hello', 'hello'), -        ('  h ello  ', 'h ello') -    ) -) -def test_tag_content_converter_for_valid(value: str, expected: str): -    assert asyncio.run(TagContentConverter.convert(None, value)) == expected - - -    ('value', 'expected'), -    ( -        ('', "Tag contents should not be empty, or filled with whitespace."), -        ('   ', "Tag contents should not be empty, or filled with whitespace.") -    ) -) -def test_tag_content_converter_for_invalid(value: str, expected: str): -    context = MagicMock() -    context.author = 'bob' - -    with pytest.raises(BadArgument, match=expected): -        asyncio.run(TagContentConverter.convert(context, value)) - - -    ('value', 'expected'), -    ( -        ('tracebacks', 'tracebacks'), -        ('Tracebacks', 'tracebacks'), -        ('  Tracebacks  ', 'tracebacks'), -    ) -) -def test_tag_name_converter_for_valid(value: str, expected: str): -    assert asyncio.run(TagNameConverter.convert(None, value)) == expected - - -    ('value', 'expected'), -    ( -        ('👋', "Don't be ridiculous, you can't use that character!"), -        ('', "Tag names should not be empty, or filled with whitespace."), -        ('  ', "Tag names should not be empty, or filled with whitespace."), -        ('42', "Tag names can't be numbers."), -        # Escape question mark as this is evaluated as regular expression. -        ('x' * 128, r"Are you insane\? That's way too long!"), -    ) -) -def test_tag_name_converter_for_invalid(value: str, expected: str): -    context = MagicMock() -    context.author = 'bob' - -    with pytest.raises(BadArgument, match=expected): -        asyncio.run(TagNameConverter.convert(context, value)) - - [email protected]('value', ('foo', 'lemon')) -def test_valid_python_identifier_for_valid(value: str): -    assert asyncio.run(ValidPythonIdentifier.convert(None, value)) == value - - [email protected]('value', ('nested.stuff', '#####')) -def test_valid_python_identifier_for_invalid(value: str): -    with pytest.raises(BadArgument, match=f'`{value}` is not a valid Python identifier'): -        asyncio.run(ValidPythonIdentifier.convert(None, value)) - - -FIXED_UTC_NOW = datetime.datetime.fromisoformat('2019-01-01T00:00:00') - - -    params=( -        # Simple duration strings -        ('1Y', {"years": 1}), -        ('1y', {"years": 1}), -        ('1year', {"years": 1}), -        ('1years', {"years": 1}), -        ('1m', {"months": 1}), -        ('1month', {"months": 1}), -        ('1months', {"months": 1}), -        ('1w', {"weeks": 1}), -        ('1W', {"weeks": 1}), -        ('1week', {"weeks": 1}), -        ('1weeks', {"weeks": 1}), -        ('1d', {"days": 1}), -        ('1D', {"days": 1}), -        ('1day', {"days": 1}), -        ('1days', {"days": 1}), -        ('1h', {"hours": 1}), -        ('1H', {"hours": 1}), -        ('1hour', {"hours": 1}), -        ('1hours', {"hours": 1}), -        ('1M', {"minutes": 1}), -        ('1minute', {"minutes": 1}), -        ('1minutes', {"minutes": 1}), -        ('1s', {"seconds": 1}), -        ('1S', {"seconds": 1}), -        ('1second', {"seconds": 1}), -        ('1seconds', {"seconds": 1}), - -        # Complex duration strings -        ( -            '1y1m1w1d1H1M1S', -            { -                "years": 1, -                "months": 1, -                "weeks": 1, -                "days": 1, -                "hours": 1, -                "minutes": 1, -                "seconds": 1 -            } -        ), -        ('5y100S', {"years": 5, "seconds": 100}), -        ('2w28H', {"weeks": 2, "hours": 28}), - -        # Duration strings with spaces -        ('1 year 2 months', {"years": 1, "months": 2}), -        ('1d 2H', {"days": 1, "hours": 2}), -        ('1 week2 days', {"weeks": 1, "days": 2}), -    ) -) -def create_future_datetime(request): -    """Yields duration string and target datetime.datetime object.""" -    duration, duration_dict = request.param -    future_datetime = FIXED_UTC_NOW + relativedelta(**duration_dict) -    yield duration, future_datetime - - -def test_duration_converter_for_valid(create_future_datetime: tuple): -    converter = Duration() -    duration, expected = create_future_datetime -    with patch('bot.converters.datetime') as mock_datetime: -        mock_datetime.utcnow.return_value = FIXED_UTC_NOW -        assert asyncio.run(converter.convert(None, duration)) == expected - - -    ('duration'), -    ( -        # Units in wrong order -        ('1d1w'), -        ('1s1y'), - -        # Duplicated units -        ('1 year 2 years'), -        ('1 M 10 minutes'), - -        # Unknown substrings -        ('1MVes'), -        ('1y3breads'), - -        # Missing amount -        ('ym'), - -        # Incorrect whitespace -        (" 1y"), -        ("1S "), -        ("1y  1m"), - -        # Garbage -        ('Guido van Rossum'), -        ('lemon lemon lemon lemon lemon lemon lemon'), -    ) -) -def test_duration_converter_for_invalid(duration: str): -    converter = Duration() -    with pytest.raises(BadArgument, match=f'`{duration}` is not a valid duration string.'): -        asyncio.run(converter.convert(None, duration)) - - -    ("datetime_string", "expected_dt"), -    ( - -        # `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ` -        ('2019-09-02T02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02 02:03:05Z', datetime.datetime(2019, 9, 2, 2, 3, 5)), - -        # `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM` -        ('2019-09-02T03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02 03:18:05+01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02T00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02 00:48:05-01:15', datetime.datetime(2019, 9, 2, 2, 3, 5)), - -        # `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM` -        ('2019-09-02T03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02 03:18:05+0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02T00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02 00:48:05-0115', datetime.datetime(2019, 9, 2, 2, 3, 5)), - -        # `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH` -        ('2019-09-02 03:03:05+01', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02T01:03:05-01', datetime.datetime(2019, 9, 2, 2, 3, 5)), - -        # `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS` -        ('2019-09-02T02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)), -        ('2019-09-02 02:03:05', datetime.datetime(2019, 9, 2, 2, 3, 5)), - -        # `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM` -        ('2019-11-12T09:15', datetime.datetime(2019, 11, 12, 9, 15)), -        ('2019-11-12 09:15', datetime.datetime(2019, 11, 12, 9, 15)), - -        # `YYYY-mm-dd` -        ('2019-04-01', datetime.datetime(2019, 4, 1)), - -        # `YYYY-mm` -        ('2019-02-01', datetime.datetime(2019, 2, 1)), - -        # `YYYY` -        ('2025', datetime.datetime(2025, 1, 1)), -    ), -) -def test_isodatetime_converter_for_valid(datetime_string: str, expected_dt: datetime.datetime): -    converter = ISODateTime() -    converted_dt = asyncio.run(converter.convert(None, datetime_string)) -    assert converted_dt.tzinfo is None -    assert converted_dt == expected_dt - - -    ("datetime_string"), -    ( -        # Make sure it doesn't interfere with the Duration converter -        ('1Y'), -        ('1d'), -        ('1H'), - -        # Check if it fails when only providing the optional time part -        ('10:10:10'), -        ('10:00'), - -        # Invalid date format -        ('19-01-01'), - -        # Other non-valid strings -        ('fisk the tag master'), -    ), -) -def test_isodatetime_converter_for_invalid(datetime_string: str): -    converter = ISODateTime() -    with pytest.raises( -        BadArgument, -        match=f"`{datetime_string}` is not a valid ISO-8601 datetime string", -    ): -        asyncio.run(converter.convert(None, datetime_string)) diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 000000000..f08239981 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,366 @@ +import asyncio +import inspect +import unittest +import unittest.mock + +import discord + +from tests import helpers + + +class DiscordMocksTests(unittest.TestCase): +    """Tests for our specialized discord.py mocks.""" + +    def test_mock_role_default_initialization(self): +        """Test if the default initialization of MockRole results in the correct object.""" +        role = helpers.MockRole() + +        # The `spec` argument makes sure `isistance` checks with `discord.Role` pass +        self.assertIsInstance(role, discord.Role) + +        self.assertEqual(role.name, "role") +        self.assertEqual(role.id, 1) +        self.assertEqual(role.position, 1) +        self.assertEqual(role.mention, "&role") + +    def test_mock_role_alternative_arguments(self): +        """Test if MockRole initializes with the arguments provided.""" +        role = helpers.MockRole( +            name="Admins", +            role_id=90210, +            position=10, +        ) + +        self.assertEqual(role.name, "Admins") +        self.assertEqual(role.id, 90210) +        self.assertEqual(role.position, 10) +        self.assertEqual(role.mention, "&Admins") + +    def test_mock_role_accepts_dynamic_arguments(self): +        """Test if MockRole accepts and sets abitrary keyword arguments.""" +        role = helpers.MockRole( +            guild="Dino Man", +            hoist=True, +        ) + +        self.assertEqual(role.guild, "Dino Man") +        self.assertTrue(role.hoist) + +    def test_mock_role_uses_position_for_less_than_greater_than(self): +        """Test if `<` and `>` comparisons for MockRole are based on its position attribute.""" +        role_one = helpers.MockRole(position=1) +        role_two = helpers.MockRole(position=2) +        role_three = helpers.MockRole(position=3) + +        self.assertLess(role_one, role_two) +        self.assertLess(role_one, role_three) +        self.assertLess(role_two, role_three) +        self.assertGreater(role_three, role_two) +        self.assertGreater(role_three, role_one) +        self.assertGreater(role_two, role_one) + +    def test_mock_member_default_initialization(self): +        """Test if the default initialization of Mockmember results in the correct object.""" +        member = helpers.MockMember() + +        # The `spec` argument makes sure `isistance` checks with `discord.Member` pass +        self.assertIsInstance(member, discord.Member) + +        self.assertEqual(member.name, "member") +        self.assertEqual(member.id, 1) +        self.assertListEqual(member.roles, [helpers.MockRole("@everyone", 1)]) +        self.assertEqual(member.mention, "@member") + +    def test_mock_member_alternative_arguments(self): +        """Test if MockMember initializes with the arguments provided.""" +        core_developer = helpers.MockRole("Core Developer", 2) +        member = helpers.MockMember( +            name="Mark", +            user_id=12345, +            roles=[core_developer] +        ) + +        self.assertEqual(member.name, "Mark") +        self.assertEqual(member.id, 12345) +        self.assertListEqual(member.roles, [helpers.MockRole("@everyone", 1), core_developer]) +        self.assertEqual(member.mention, "@Mark") + +    def test_mock_member_accepts_dynamic_arguments(self): +        """Test if MockMember accepts and sets abitrary keyword arguments.""" +        member = helpers.MockMember( +            nick="Dino Man", +            colour=discord.Colour.default(), +        ) + +        self.assertEqual(member.nick, "Dino Man") +        self.assertEqual(member.colour, discord.Colour.default()) + +    def test_mock_guild_default_initialization(self): +        """Test if the default initialization of Mockguild results in the correct object.""" +        guild = helpers.MockGuild() + +        # The `spec` argument makes sure `isistance` checks with `discord.Guild` pass +        self.assertIsInstance(guild, discord.Guild) + +        self.assertListEqual(guild.roles, [helpers.MockRole("@everyone", 1)]) +        self.assertListEqual(guild.members, []) + +    def test_mock_guild_alternative_arguments(self): +        """Test if MockGuild initializes with the arguments provided.""" +        core_developer = helpers.MockRole("Core Developer", 2) +        guild = helpers.MockGuild( +            roles=[core_developer], +            members=[helpers.MockMember(user_id=54321)], +        ) + +        self.assertListEqual(guild.roles, [helpers.MockRole("@everyone", 1), core_developer]) +        self.assertListEqual(guild.members, [helpers.MockMember(user_id=54321)]) + +    def test_mock_guild_accepts_dynamic_arguments(self): +        """Test if MockGuild accepts and sets abitrary keyword arguments.""" +        guild = helpers.MockGuild( +            emojis=(":hyperjoseph:", ":pensive_ela:"), +            premium_subscription_count=15, +        ) + +        self.assertTupleEqual(guild.emojis, (":hyperjoseph:", ":pensive_ela:")) +        self.assertEqual(guild.premium_subscription_count, 15) + +    def test_mock_bot_default_initialization(self): +        """Tests if MockBot initializes with the correct values.""" +        bot = helpers.MockBot() + +        # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Bot` pass +        self.assertIsInstance(bot, discord.ext.commands.Bot) + +    def test_mock_context_default_initialization(self): +        """Tests if MockContext initializes with the correct values.""" +        context = helpers.MockContext() + +        # The `spec` argument makes sure `isistance` checks with `discord.ext.commands.Context` pass +        self.assertIsInstance(context, discord.ext.commands.Context) + +        self.assertIsInstance(context.bot, helpers.MockBot) +        self.assertIsInstance(context.guild, helpers.MockGuild) +        self.assertIsInstance(context.author, helpers.MockMember) + +    def test_mocks_allows_access_to_attributes_part_of_spec(self): +        """Accessing attributes that are valid for the objects they mock should succeed.""" +        mocks = ( +            (helpers.MockGuild(), 'name'), +            (helpers.MockRole(), 'hoist'), +            (helpers.MockMember(), 'display_name'), +            (helpers.MockBot(), 'user'), +            (helpers.MockContext(), 'invoked_with'), +            (helpers.MockTextChannel(), 'last_message'), +            (helpers.MockMessage(), 'mention_everyone'), +        ) + +        for mock, valid_attribute in mocks: +            with self.subTest(mock=mock): +                try: +                    getattr(mock, valid_attribute) +                except AttributeError: +                    msg = f"accessing valid attribute `{valid_attribute}` raised an AttributeError" +                    self.fail(msg) + +    @unittest.mock.patch(f'{__name__}.DiscordMocksTests.subTest') +    @unittest.mock.patch(f'{__name__}.getattr') +    def test_mock_allows_access_to_attributes_test(self, mock_getattr, mock_subtest): +        """The valid attribute test should raise an AssertionError after an AttributeError.""" +        mock_getattr.side_effect = AttributeError + +        msg = "accessing valid attribute `name` raised an AttributeError" +        with self.assertRaises(AssertionError, msg=msg): +            self.test_mocks_allows_access_to_attributes_part_of_spec() + +    def test_mocks_rejects_access_to_attributes_not_part_of_spec(self): +        """Accessing attributes that are invalid for the objects they mock should fail.""" +        mocks = ( +            helpers.MockGuild(), +            helpers.MockRole(), +            helpers.MockMember(), +            helpers.MockBot(), +            helpers.MockContext(), +            helpers.MockTextChannel(), +            helpers.MockMessage(), +        ) + +        for mock in mocks: +            with self.subTest(mock=mock): +                with self.assertRaises(AttributeError): +                    mock.the_cake_is_a_lie + +    def test_custom_mock_methods_are_valid_discord_object_methods(self): +        """The `AsyncMock` attributes of the mocks should be valid for the class they're mocking.""" +        mocks = ( +            (helpers.MockGuild, helpers.guild_instance), +            (helpers.MockRole, helpers.role_instance), +            (helpers.MockMember, helpers.member_instance), +            (helpers.MockBot, helpers.bot_instance), +            (helpers.MockContext, helpers.context_instance), +            (helpers.MockTextChannel, helpers.channel_instance), +            (helpers.MockMessage, helpers.message_instance), +        ) + +        for mock_class, instance in mocks: +            mock = mock_class() +            async_methods = ( +                attr for attr in dir(mock) if isinstance(getattr(mock, attr), helpers.AsyncMock) +            ) + +            # spec_mock = unittest.mock.MagicMock(spec=instance) +            for method in async_methods: +                with self.subTest(mock_class=mock_class, method=method): +                    try: +                        getattr(instance, method) +                    except AttributeError: +                        msg = f"method {method} is not a method attribute of {instance.__class__}" +                        self.fail(msg) + +    @unittest.mock.patch(f'{__name__}.DiscordMocksTests.subTest') +    def test_the_custom_mock_methods_test(self, subtest_mock): +        """The custom method test should raise AssertionError for invalid methods.""" +        class FakeMockBot(helpers.AttributeMock, unittest.mock.MagicMock): +            """Fake MockBot class with invalid attribute/method `release_the_walrus`.""" + +            attribute_mocktype = unittest.mock.MagicMock + +            def __init__(self, **kwargs): +                super().__init__(spec=helpers.bot_instance, **kwargs) + +                # Fake attribute +                self.release_the_walrus = helpers.AsyncMock() + +        with unittest.mock.patch("tests.helpers.MockBot", new=FakeMockBot): +            msg = "method release_the_walrus is not a valid method of <class 'discord.ext.commands.bot.Bot'>" +            with self.assertRaises(AssertionError, msg=msg): +                self.test_custom_mock_methods_are_valid_discord_object_methods() + + +class MockObjectTests(unittest.TestCase): +    """Tests the mock objects and mixins we've defined.""" + +    @classmethod +    def setUpClass(cls): +        cls.hashable_mocks = (helpers.MockRole, helpers.MockMember, helpers.MockGuild) + +    def test_colour_mixin(self): +        """Test if the ColourMixin adds aliasing of color -> colour for child classes.""" +        class MockHemlock(unittest.mock.MagicMock, helpers.ColourMixin): +            pass + +        hemlock = MockHemlock() +        hemlock.color = 1 +        self.assertEqual(hemlock.colour, 1) +        self.assertEqual(hemlock.colour, hemlock.color) + +    def test_hashable_mixin_hash_returns_id(self): +        """Test if the HashableMixing uses the id attribute for hashing.""" +        class MockScragly(unittest.mock.Mock, helpers.HashableMixin): +            pass + +        scragly = MockScragly() +        scragly.id = 10 +        self.assertEqual(hash(scragly), scragly.id) + +    def test_hashable_mixin_uses_id_for_equality_comparison(self): +        """Test if the HashableMixing uses the id attribute for hashing.""" +        class MockScragly(unittest.mock.Mock, helpers.HashableMixin): +            pass + +        scragly = MockScragly(spec=object) +        scragly.id = 10 +        eevee = MockScragly(spec=object) +        eevee.id = 10 +        python = MockScragly(spec=object) +        python.id = 20 + +        self.assertTrue(scragly == eevee) +        self.assertFalse(scragly == python) + +    def test_hashable_mixin_uses_id_for_nonequality_comparison(self): +        """Test if the HashableMixing uses the id attribute for hashing.""" +        class MockScragly(unittest.mock.Mock, helpers.HashableMixin): +            pass + +        scragly = MockScragly(spec=object) +        scragly.id = 10 +        eevee = MockScragly(spec=object) +        eevee.id = 10 +        python = MockScragly(spec=object) +        python.id = 20 + +        self.assertTrue(scragly != python) +        self.assertFalse(scragly != eevee) + +    def test_mock_class_with_hashable_mixin_uses_id_for_hashing(self): +        """Test if the MagicMock subclasses that implement the HashableMixin use id for hash.""" +        for mock in self.hashable_mocks: +            with self.subTest(mock_class=mock): +                instance = helpers.MockRole(role_id=100) +                self.assertEqual(hash(instance), instance.id) + +    def test_mock_class_with_hashable_mixin_uses_id_for_equality(self): +        """Test if MagicMocks that implement the HashableMixin use id for equality comparisons.""" +        for mock_class in self.hashable_mocks: +            with self.subTest(mock_class=mock_class): +                instance_one = mock_class() +                instance_two = mock_class() +                instance_three = mock_class() + +                instance_one.id = 10 +                instance_two.id = 10 +                instance_three.id = 20 + +                self.assertTrue(instance_one == instance_two) +                self.assertFalse(instance_one == instance_three) + +    def test_mock_class_with_hashable_mixin_uses_id_for_nonequality(self): +        """Test if MagicMocks that implement HashableMixin use id for nonequality comparisons.""" +        for mock_class in self.hashable_mocks: +            with self.subTest(mock_class=mock_class): +                instance_one = mock_class() +                instance_two = mock_class() +                instance_three = mock_class() + +                instance_one.id = 10 +                instance_two.id = 10 +                instance_three.id = 20 + +                self.assertFalse(instance_one != instance_two) +                self.assertTrue(instance_one != instance_three) + +    def test_spec_propagation_of_mock_subclasses(self): +        """Test if the `spec` does not propagate to attributes of the mock object.""" +        test_values = ( +            (helpers.MockGuild, "region"), +            (helpers.MockRole, "mentionable"), +            (helpers.MockMember, "display_name"), +            (helpers.MockBot, "owner_id"), +            (helpers.MockContext, "command_failed"), +        ) + +        for mock_type, valid_attribute in test_values: +            with self.subTest(mock_type=mock_type, attribute=valid_attribute): +                mock = mock_type() +                self.assertTrue(isinstance(mock, mock_type)) +                attribute = getattr(mock, valid_attribute) +                self.assertTrue(isinstance(attribute, mock_type.attribute_mocktype)) + +    def test_async_mock_provides_coroutine_for_dunder_call(self): +        """Test if AsyncMock objects have a coroutine for their __call__ method.""" +        async_mock = helpers.AsyncMock() +        self.assertTrue(inspect.iscoroutinefunction(async_mock.__call__)) + +        coroutine = async_mock() +        self.assertTrue(inspect.iscoroutine(coroutine)) +        self.assertIsNotNone(asyncio.run(coroutine)) + +    def test_async_test_decorator_allows_synchronous_call_to_async_def(self): +        """Test if the `async_test` decorator allows an `async def` to be called synchronously.""" +        @helpers.async_test +        async def kosayoda(): +            return "return value" + +        self.assertEqual(kosayoda(), "return value") diff --git a/tests/test_resources.py b/tests/test_resources.py deleted file mode 100644 index bcf124f05..000000000 --- a/tests/test_resources.py +++ /dev/null @@ -1,13 +0,0 @@ -import json -from pathlib import Path - - -def test_stars_valid(): -    """Validates that `bot/resources/stars.json` contains a list of strings.""" - -    path = Path('bot', 'resources', 'stars.json') -    content = path.read_text() -    data = json.loads(content) - -    for name in data: -        assert type(name) is str diff --git a/tests/utils/test_checks.py b/tests/utils/test_checks.py deleted file mode 100644 index 7121acebd..000000000 --- a/tests/utils/test_checks.py +++ /dev/null @@ -1,66 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from bot.utils import checks - - -def context(): -    return MagicMock() - - -def test_with_role_check_without_guild(context): -    context.guild = None - -    assert not checks.with_role_check(context) - - -def test_with_role_check_with_guild_without_required_role(context): -    context.guild = True -    context.author.roles = [] - -    assert not checks.with_role_check(context) - - -def test_with_role_check_with_guild_with_required_role(context): -    context.guild = True -    role = MagicMock() -    role.id = 42 -    context.author.roles = (role,) - -    assert checks.with_role_check(context, role.id) - - -def test_without_role_check_without_guild(context): -    context.guild = None - -    assert not checks.without_role_check(context) - - -def test_without_role_check_with_unwanted_role(context): -    context.guild = True -    role = MagicMock() -    role.id = 42 -    context.author.roles = (role,) - -    assert not checks.without_role_check(context, role.id) - - -def test_without_role_check_without_unwanted_role(context): -    context.guild = True -    role = MagicMock() -    role.id = 42 -    context.author.roles = (role,) - -    assert checks.without_role_check(context, role.id + 10) - - -def test_in_channel_check_for_correct_channel(context): -    context.channel.id = 42 -    assert checks.in_channel_check(context, context.channel.id) - - -def test_in_channel_check_for_incorrect_channel(context): -    context.channel.id = 42 -    assert not checks.in_channel_check(context, context.channel.id + 10) | 
