From 3189afa9a02fc0333b4d814a48f17122af768345 Mon Sep 17 00:00:00 2001 From: kosayoda Date: Sat, 5 Sep 2020 09:03:46 +0800 Subject: Add test for everyone_ping rule. --- tests/bot/rules/test_everyone_ping.py | 102 ++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/bot/rules/test_everyone_ping.py (limited to 'tests') diff --git a/tests/bot/rules/test_everyone_ping.py b/tests/bot/rules/test_everyone_ping.py new file mode 100644 index 000000000..3ecc43cdc --- /dev/null +++ b/tests/bot/rules/test_everyone_ping.py @@ -0,0 +1,102 @@ +from typing import Iterable + +from bot.rules import everyone_ping +from tests.bot.rules import DisallowedCase, RuleTest +from tests.helpers import MockGuild, MockMessage + +NUM_GUILD_MEMBERS = 100 + + +def make_msg(author: str, message: str) -> MockMessage: + """Build a message with `message` as the content sent.""" + mocked_guild = MockGuild(member_count=NUM_GUILD_MEMBERS) + return MockMessage(author=author, content=message, guild=mocked_guild) + + +class EveryonePingRuleTest(RuleTest): + """Tests the `everyone_ping` antispam rule.""" + + def setUp(self): + self.apply = everyone_ping.apply + self.config = { + "max": 0, # Max allowed @everyone pings per user + "interval": 10, + } + + async def test_disallows_everyone_ping(self): + """Cases with an @everyone ping.""" + cases = ( + DisallowedCase( + [make_msg("bob", "@everyone")], + ("bob",), + 1 + ), + DisallowedCase( + [make_msg("bob", "Let me ping @everyone in the server.")], + ("bob",), + 1 + ), + DisallowedCase( + [make_msg("bob", "`codeblock message` and @everyone ping")], + ("bob",), + 1 + ), + DisallowedCase( + [make_msg("bob", "`sandwich` @everyone `ping between codeblocks`.")], + ("bob",), + 1 + ), + DisallowedCase( + [make_msg("bob", "This is a multiline\n@everyone\nping.")], + ("bob",), + 1 + ), + # Not actually valid code blocks + DisallowedCase( + [make_msg("bob", "`@everyone``")], + ("bob",), + 1 + ), + DisallowedCase( + [make_msg("bob", "`@everyone``````")], + ("bob",), + 1 + ), + DisallowedCase( + [make_msg("bob", "``@everyone``````")], + ("bob",), + 1 + ), + ) + + await self.run_disallowed(cases) + + async def test_allows_inline_codeblock_everyone_ping(self): + """Cases with an @everyone ping in an inline codeblock.""" + cases = ( + [make_msg("bob", "Codeblock has `@everyone` ping.")], + [make_msg("bob", "Multiple `codeblocks` including `@everyone` ping.")], + [make_msg("bob", "This is a valid ``inline @everyone` ping.")], + ) + + await self.run_allowed(cases) + + async def test_allows_multiline_codeblock_everyone_ping(self): + """Cases with an @everyone ping in a multiline codeblock.""" + cases = ( + [make_msg("bob", "```Multiline codeblock has\nan `@everyone` ping.```")], + [make_msg("bob", "``` `@everyone``` ` `")], + ) + + await self.run_allowed(cases) + + def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: + last_message = case.recent_messages[0] + return tuple( + msg + for msg in case.recent_messages + if msg.author == last_message.author + ) + + def get_report(self, case: DisallowedCase) -> str: + return "pinged the everyone role" -- cgit v1.2.3 From 0c6215f3f6e7247a5e13199931f733a8e203047e Mon Sep 17 00:00:00 2001 From: kosayoda Date: Mon, 7 Sep 2020 21:08:54 +0800 Subject: Remove everyone_ping rule from antispam. The feature will be moved to the filtering cog. --- bot/cogs/antispam.py | 1 - bot/rules/__init__.py | 1 - bot/rules/everyone_ping.py | 44 --------------- config-default.yml | 4 -- tests/bot/rules/test_everyone_ping.py | 102 ---------------------------------- 5 files changed, 152 deletions(-) delete mode 100644 bot/rules/everyone_ping.py delete mode 100644 tests/bot/rules/test_everyone_ping.py (limited to 'tests') diff --git a/bot/cogs/antispam.py b/bot/cogs/antispam.py index b8939113f..5c97621fb 100644 --- a/bot/cogs/antispam.py +++ b/bot/cogs/antispam.py @@ -36,7 +36,6 @@ RULE_FUNCTION_MAPPING = { 'mentions': rules.apply_mentions, 'newlines': rules.apply_newlines, 'role_mentions': rules.apply_role_mentions, - 'everyone_ping': rules.apply_everyone_ping, } diff --git a/bot/rules/__init__.py b/bot/rules/__init__.py index 8a69cadee..a01ceae73 100644 --- a/bot/rules/__init__.py +++ b/bot/rules/__init__.py @@ -10,4 +10,3 @@ from .links import apply as apply_links from .mentions import apply as apply_mentions from .newlines import apply as apply_newlines from .role_mentions import apply as apply_role_mentions -from .everyone_ping import apply as apply_everyone_ping diff --git a/bot/rules/everyone_ping.py b/bot/rules/everyone_ping.py deleted file mode 100644 index 8fc03b924..000000000 --- a/bot/rules/everyone_ping.py +++ /dev/null @@ -1,44 +0,0 @@ -import random -import re -from typing import Dict, Iterable, List, Optional, Tuple - -from discord import Embed, Member, Message - -from bot.constants import Colours, Guild, NEGATIVE_REPLIES - -# Generate regex for checking for pings: -guild_id = Guild.id -EVERYONE_PING_RE = re.compile(rf"@everyone|<@&{guild_id}>") -CODE_BLOCK_RE = re.compile( - r"(?P``?)[^`]+?(?P=delim)(?!`+)" # Inline codeblock - r"|```(.+?)```", # Multiline codeblock - re.DOTALL | re.MULTILINE -) - - -async def apply( - last_message: Message, - recent_messages: List[Message], - config: Dict[str, int], -) -> Optional[Tuple[str, Iterable[Member], Iterable[Message]]]: - """Detects if a user has sent an '@everyone' ping.""" - relevant_messages = tuple(msg for msg in recent_messages if msg.author == last_message.author) - - everyone_messages_count = 0 - for msg in relevant_messages: - content = CODE_BLOCK_RE.sub("", msg.content) # Remove codeblocks in the message - if matches := len(EVERYONE_PING_RE.findall(content)): - everyone_messages_count += matches - - if everyone_messages_count > config["max"]: - # Send the channel an embed giving the user more info: - embed_text = f"Please don't try to ping {last_message.guild.member_count:,} people." - embed = Embed(title=random.choice(NEGATIVE_REPLIES), description=embed_text, colour=Colours.soft_red) - await last_message.channel.send(embed=embed) - - return ( - "pinged the everyone role", - (last_message.author,), - relevant_messages, - ) - return None diff --git a/config-default.yml b/config-default.yml index e9324c62f..c1eef713f 100644 --- a/config-default.yml +++ b/config-default.yml @@ -389,10 +389,6 @@ anti_spam: interval: 10 max: 3 - everyone_ping: - interval: 10 - max: 0 - reddit: subreddits: diff --git a/tests/bot/rules/test_everyone_ping.py b/tests/bot/rules/test_everyone_ping.py deleted file mode 100644 index 3ecc43cdc..000000000 --- a/tests/bot/rules/test_everyone_ping.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Iterable - -from bot.rules import everyone_ping -from tests.bot.rules import DisallowedCase, RuleTest -from tests.helpers import MockGuild, MockMessage - -NUM_GUILD_MEMBERS = 100 - - -def make_msg(author: str, message: str) -> MockMessage: - """Build a message with `message` as the content sent.""" - mocked_guild = MockGuild(member_count=NUM_GUILD_MEMBERS) - return MockMessage(author=author, content=message, guild=mocked_guild) - - -class EveryonePingRuleTest(RuleTest): - """Tests the `everyone_ping` antispam rule.""" - - def setUp(self): - self.apply = everyone_ping.apply - self.config = { - "max": 0, # Max allowed @everyone pings per user - "interval": 10, - } - - async def test_disallows_everyone_ping(self): - """Cases with an @everyone ping.""" - cases = ( - DisallowedCase( - [make_msg("bob", "@everyone")], - ("bob",), - 1 - ), - DisallowedCase( - [make_msg("bob", "Let me ping @everyone in the server.")], - ("bob",), - 1 - ), - DisallowedCase( - [make_msg("bob", "`codeblock message` and @everyone ping")], - ("bob",), - 1 - ), - DisallowedCase( - [make_msg("bob", "`sandwich` @everyone `ping between codeblocks`.")], - ("bob",), - 1 - ), - DisallowedCase( - [make_msg("bob", "This is a multiline\n@everyone\nping.")], - ("bob",), - 1 - ), - # Not actually valid code blocks - DisallowedCase( - [make_msg("bob", "`@everyone``")], - ("bob",), - 1 - ), - DisallowedCase( - [make_msg("bob", "`@everyone``````")], - ("bob",), - 1 - ), - DisallowedCase( - [make_msg("bob", "``@everyone``````")], - ("bob",), - 1 - ), - ) - - await self.run_disallowed(cases) - - async def test_allows_inline_codeblock_everyone_ping(self): - """Cases with an @everyone ping in an inline codeblock.""" - cases = ( - [make_msg("bob", "Codeblock has `@everyone` ping.")], - [make_msg("bob", "Multiple `codeblocks` including `@everyone` ping.")], - [make_msg("bob", "This is a valid ``inline @everyone` ping.")], - ) - - await self.run_allowed(cases) - - async def test_allows_multiline_codeblock_everyone_ping(self): - """Cases with an @everyone ping in a multiline codeblock.""" - cases = ( - [make_msg("bob", "```Multiline codeblock has\nan `@everyone` ping.```")], - [make_msg("bob", "``` `@everyone``` ` `")], - ) - - await self.run_allowed(cases) - - def relevant_messages(self, case: DisallowedCase) -> Iterable[MockMessage]: - last_message = case.recent_messages[0] - return tuple( - msg - for msg in case.recent_messages - if msg.author == last_message.author - ) - - def get_report(self, case: DisallowedCase) -> str: - return "pinged the everyone role" -- cgit v1.2.3 From fc058490d62fa373e6ebb01392c511027401e72f Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Sat, 19 Sep 2020 21:45:54 +0200 Subject: Use async-rediscache package for our redis caches I've migrated our redis caches over to the async-rediscache package that we've recently released (https://git.pydis.com/async-rediscache). The main functionality remains the same, although the package handles some things, like getting the active session, differently internally. The main changes you'll note for our bot are: - We create a RedisSession instance and ensure that it connects before we even create a Bot instance in `__main__.py`. - We are now actually using a connection pool instead of a single connection. - Our Bot subclass now has a new required kwarg: `redis_session`. - Bool values are now properly converted to and from typestrings. In addition, I've made sure that our MockBot passes a MagicMock for the new `redis_session` kwarg when creating a Bot instance for the spec_set. Signed-off-by: Sebastiaan Zeeff --- Pipfile | 2 +- Pipfile.lock | 243 ++++++++++++++++++++++++---------------------- bot/__main__.py | 20 ++++ bot/bot.py | 49 ++-------- bot/cogs/dm_relay.py | 2 +- bot/cogs/filtering.py | 4 +- bot/cogs/help_channels.py | 2 +- bot/cogs/verification.py | 2 +- tests/helpers.py | 6 +- 9 files changed, 167 insertions(+), 163 deletions(-) (limited to 'tests') diff --git a/Pipfile b/Pipfile index 6fff2223e..8ab5d230e 100644 --- a/Pipfile +++ b/Pipfile @@ -8,12 +8,12 @@ aio-pika = "~=6.1" aiodns = "~=2.0" aiohttp = "~=3.5" aioredis = "~=1.3.1" +"async-rediscache[fakeredis]" = "~=0.1.2" beautifulsoup4 = "~=4.9" colorama = {version = "~=0.4.3",sys_platform = "== 'win32'"} coloredlogs = "~=14.0" deepdiff = "~=4.0" discord.py = "~=1.4.0" -fakeredis = "~=1.4" feedparser = "~=5.2" fuzzywuzzy = "~=0.17" lxml = "~=4.4" diff --git a/Pipfile.lock b/Pipfile.lock index 50ddd478c..97436413d 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "1905fd7eb15074ddbf04f2177b6cdd65edc4c74cb5fcbf4e6ca08ef649ba8a3c" + "sha256": "63eec3557c8bfd42191cb5a7525d6e298471c16863adff485d917e08b72cd787" }, "pipfile-spec": 6, "requires": { @@ -18,11 +18,11 @@ "default": { "aio-pika": { "hashes": [ - "sha256:c4cbbeb85b3c7bf81bc127371846cd949e6231717ce1e6ac7ee1dd5ede21f866", - "sha256:ec7fef24f588d90314873463ab4f2c3debce0bd8830e49e3786586be96bc2e8e" + "sha256:4a20d4d941e1f113a950ea529a90bd9159c8d7aafaa1c71e9c707c8c2b526ea6", + "sha256:7bf3f183df1eb348d007210a0c1a3c5c755f1b3def1a9a395e93f30b91da1daf" ], "index": "pypi", - "version": "==6.6.1" + "version": "==6.7.0" }, "aiodns": { "hashes": [ @@ -73,6 +73,18 @@ ], "version": "==0.7.12" }, + "async-rediscache": { + "extras": [ + "fakeredis" + ], + "hashes": [ + "sha256:407aed1aad97bf22f690eca5369806d22eefc8ca104a52c1f1bd47dd6db45fc2", + "sha256:563aaff79ec611a92a0ad78e39ff159e3a4b4cf0bea41e061de5f3701a17d50c" + ], + "index": "pypi", + "markers": "python_version ~= '3.7'", + "version": "==0.1.2" + }, "async-timeout": { "hashes": [ "sha256:0c3c816a028d47f659d6ff5c745cb2acf1f966da1fe5c19c77a70282b25f4c5f", @@ -83,11 +95,11 @@ }, "attrs": { "hashes": [ - "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", - "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" + "sha256:26b54ddbbb9ee1d34d5d3668dd37d6cf74990ab23c828c2888dccdceee395594", + "sha256:fce7fc47dfc976152e82d53ff92fa0407700c21acd20886a13777a0d20e655dc" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==19.3.0" + "version": "==20.2.0" }, "babel": { "hashes": [ @@ -115,36 +127,36 @@ }, "cffi": { "hashes": [ - "sha256:267adcf6e68d77ba154334a3e4fc921b8e63cbb38ca00d33d40655d4228502bc", - "sha256:26f33e8f6a70c255767e3c3f957ccafc7f1f706b966e110b855bfe944511f1f9", - "sha256:3cd2c044517f38d1b577f05927fb9729d3396f1d44d0c659a445599e79519792", - "sha256:4a03416915b82b81af5502459a8a9dd62a3c299b295dcdf470877cb948d655f2", - "sha256:4ce1e995aeecf7cc32380bc11598bfdfa017d592259d5da00fc7ded11e61d022", - "sha256:4f53e4128c81ca3212ff4cf097c797ab44646a40b42ec02a891155cd7a2ba4d8", - "sha256:4fa72a52a906425416f41738728268072d5acfd48cbe7796af07a923236bcf96", - "sha256:66dd45eb9530e3dde8f7c009f84568bc7cac489b93d04ac86e3111fb46e470c2", - "sha256:6923d077d9ae9e8bacbdb1c07ae78405a9306c8fd1af13bfa06ca891095eb995", - "sha256:833401b15de1bb92791d7b6fb353d4af60dc688eaa521bd97203dcd2d124a7c1", - "sha256:8416ed88ddc057bab0526d4e4e9f3660f614ac2394b5e019a628cdfff3733849", - "sha256:892daa86384994fdf4856cb43c93f40cbe80f7f95bb5da94971b39c7f54b3a9c", - "sha256:98be759efdb5e5fa161e46d404f4e0ce388e72fbf7d9baf010aff16689e22abe", - "sha256:a6d28e7f14ecf3b2ad67c4f106841218c8ab12a0683b1528534a6c87d2307af3", - "sha256:b1d6ebc891607e71fd9da71688fcf332a6630b7f5b7f5549e6e631821c0e5d90", - "sha256:b2a2b0d276a136146e012154baefaea2758ef1f56ae9f4e01c612b0831e0bd2f", - "sha256:b87dfa9f10a470eee7f24234a37d1d5f51e5f5fa9eeffda7c282e2b8f5162eb1", - "sha256:bac0d6f7728a9cc3c1e06d4fcbac12aaa70e9379b3025b27ec1226f0e2d404cf", - "sha256:c991112622baee0ae4d55c008380c32ecfd0ad417bcd0417ba432e6ba7328caa", - "sha256:cda422d54ee7905bfc53ee6915ab68fe7b230cacf581110df4272ee10462aadc", - "sha256:d3148b6ba3923c5850ea197a91a42683f946dba7e8eb82dfa211ab7e708de939", - "sha256:d6033b4ffa34ef70f0b8086fd4c3df4bf801fee485a8a7d4519399818351aa8e", - "sha256:ddff0b2bd7edcc8c82d1adde6dbbf5e60d57ce985402541cd2985c27f7bec2a0", - "sha256:e23cb7f1d8e0f93addf0cae3c5b6f00324cccb4a7949ee558d7b6ca973ab8ae9", - "sha256:effd2ba52cee4ceff1a77f20d2a9f9bf8d50353c854a282b8760ac15b9833168", - "sha256:f90c2267101010de42f7273c94a1f026e56cbc043f9330acd8a80e64300aba33", - "sha256:f960375e9823ae6a07072ff7f8a85954e5a6434f97869f50d0e41649a1c8144f", - "sha256:fcf32bf76dc25e30ed793145a57426064520890d7c02866eb93d3e4abe516948" - ], - "version": "==1.14.1" + "sha256:0da50dcbccd7cb7e6c741ab7912b2eff48e85af217d72b57f80ebc616257125e", + "sha256:12a453e03124069b6896107ee133ae3ab04c624bb10683e1ed1c1663df17c13c", + "sha256:15419020b0e812b40d96ec9d369b2bc8109cc3295eac6e013d3261343580cc7e", + "sha256:15a5f59a4808f82d8ec7364cbace851df591c2d43bc76bcbe5c4543a7ddd1bf1", + "sha256:23e44937d7695c27c66a54d793dd4b45889a81b35c0751ba91040fe825ec59c4", + "sha256:29c4688ace466a365b85a51dcc5e3c853c1d283f293dfcc12f7a77e498f160d2", + "sha256:57214fa5430399dffd54f4be37b56fe22cedb2b98862550d43cc085fb698dc2c", + "sha256:577791f948d34d569acb2d1add5831731c59d5a0c50a6d9f629ae1cefd9ca4a0", + "sha256:6539314d84c4d36f28d73adc1b45e9f4ee2a89cdc7e5d2b0a6dbacba31906798", + "sha256:65867d63f0fd1b500fa343d7798fa64e9e681b594e0a07dc934c13e76ee28fb1", + "sha256:672b539db20fef6b03d6f7a14b5825d57c98e4026401fce838849f8de73fe4d4", + "sha256:6843db0343e12e3f52cc58430ad559d850a53684f5b352540ca3f1bc56df0731", + "sha256:7057613efefd36cacabbdbcef010e0a9c20a88fc07eb3e616019ea1692fa5df4", + "sha256:76ada88d62eb24de7051c5157a1a78fd853cca9b91c0713c2e973e4196271d0c", + "sha256:837398c2ec00228679513802e3744d1e8e3cb1204aa6ad408b6aff081e99a487", + "sha256:8662aabfeab00cea149a3d1c2999b0731e70c6b5bac596d95d13f643e76d3d4e", + "sha256:95e9094162fa712f18b4f60896e34b621df99147c2cee216cfa8f022294e8e9f", + "sha256:99cc66b33c418cd579c0f03b77b94263c305c389cb0c6972dac420f24b3bf123", + "sha256:9b219511d8b64d3fa14261963933be34028ea0e57455baf6781fe399c2c3206c", + "sha256:ae8f34d50af2c2154035984b8b5fc5d9ed63f32fe615646ab435b05b132ca91b", + "sha256:b9aa9d8818c2e917fa2c105ad538e222a5bce59777133840b93134022a7ce650", + "sha256:bf44a9a0141a082e89c90e8d785b212a872db793a0080c20f6ae6e2a0ebf82ad", + "sha256:c0b48b98d79cf795b0916c57bebbc6d16bb43b9fc9b8c9f57f4cf05881904c75", + "sha256:da9d3c506f43e220336433dffe643fbfa40096d408cb9b7f2477892f369d5f82", + "sha256:e4082d832e36e7f9b2278bc774886ca8207346b99f278e54c9de4834f17232f7", + "sha256:e4b9b7af398c32e408c00eb4e0d33ced2f9121fd9fb978e6c1b57edd014a7d15", + "sha256:e613514a82539fc48291d01933951a13ae93b6b444a88782480be32245ed4afa", + "sha256:f5033952def24172e60493b68717792e3aebb387a8d186c43c020d9363ee7281" + ], + "version": "==1.14.2" }, "chardet": { "hashes": [ @@ -188,11 +200,11 @@ }, "discord.py": { "hashes": [ - "sha256:2b1846bfa382b54f4eace8e437a9f59f185388c5b08749ac0e1bbd98e05bfde5", - "sha256:f3db9531fccc391f51de65cfa46133106a9ba12ff2927aca6c14bffd3b7f17b5" + "sha256:98ea3096a3585c9c379209926f530808f5fcf4930928d8cfb579d2562d119570", + "sha256:f9decb3bfa94613d922376288617e6a6f969260923643e2897f4540c34793442" ], "markers": "python_full_version >= '3.5.3'", - "version": "==1.4.0" + "version": "==1.4.1" }, "docutils": { "hashes": [ @@ -204,11 +216,10 @@ }, "fakeredis": { "hashes": [ - "sha256:790c85ad0f3b2967aba1f51767021bc59760fcb612159584be018ea7384f7fd2", - "sha256:fdfe06f277092d022c271fcaefdc1f0c8d9bfa8cb15374cae41d66a20bd96d2b" + "sha256:7ea0866ba5edb40fe2e9b1722535df0c7e6b91d518aa5f50d96c2fff3ea7f4c2", + "sha256:aad8836ffe0319ffbba66dcf872ac6e7e32d1f19790e31296ba58445efb0a5c7" ], - "index": "pypi", - "version": "==1.4.2" + "version": "==1.4.3" }, "feedparser": { "hashes": [ @@ -350,10 +361,11 @@ }, "markdownify": { "hashes": [ - "sha256:28ce67d1888e4908faaab7b04d2193cda70ea4f902f156a21d0aaea55e63e0a1" + "sha256:30be8340724e706c9e811c27fe8c1542cf74a15b46827924fff5c54b40dd9b0d", + "sha256:a69588194fd76634f0139d6801b820fd652dc5eeba9530e90d323dfdc0155252" ], "index": "pypi", - "version": "==0.4.1" + "version": "==0.5.3" }, "markupsafe": { "hashes": [ @@ -396,11 +408,11 @@ }, "more-itertools": { "hashes": [ - "sha256:68c70cc7167bdf5c7c9d8f6954a7837089c6a36bf565383919bb595efb8a17e5", - "sha256:b78134b2063dd214000685165d81c154522c3ee0a1c0d4d113c80361c234c5a2" + "sha256:6f83822ae94818eae2612063a5101a7311e68ae8002005b5e05f03fd74a86a20", + "sha256:9b30f12df9393f0d28af9210ff8efe48d10c94f73e5daf886f10c4b0b0b4f03c" ], "index": "pypi", - "version": "==8.4.0" + "version": "==8.5.0" }, "multidict": { "hashes": [ @@ -491,11 +503,11 @@ }, "pygments": { "hashes": [ - "sha256:647344a061c249a3b74e230c739f434d7ea4d8b1d5f3721bc0f3558049b38f44", - "sha256:ff7a40b4860b727ab48fad6360eb351cc1b33cbf9b15a0f689ca5353e9463324" + "sha256:307543fe65c0947b126e83dd5a61bd8acbd84abec11f43caebaf5534cbc17998", + "sha256:926c3f319eda178d1bd90851e4317e6d8cdb5e292a3386aac9bd75eca29cf9c7" ], "markers": "python_version >= '3.5'", - "version": "==2.6.1" + "version": "==2.7.1" }, "pyparsing": { "hashes": [ @@ -555,11 +567,11 @@ }, "sentry-sdk": { "hashes": [ - "sha256:21b17d6aa064c0fb703a7c00f77cf6c9c497cf2f83345c28892980a5e742d116", - "sha256:4fc97114c77d005467b9b1a29f042e2bc01923cb683b0ef0bbda46e79fa12532" + "sha256:1a086486ff9da15791f294f6e9915eb3747d161ef64dee2d038a4d0b4a369b24", + "sha256:45486deb031cea6bbb25a540d7adb4dd48cd8a1cc31e6a5ce9fb4f792a572e9a" ], "index": "pypi", - "version": "==0.16.3" + "version": "==0.17.6" }, "six": { "hashes": [ @@ -697,11 +709,11 @@ }, "attrs": { "hashes": [ - "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", - "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" + "sha256:26b54ddbbb9ee1d34d5d3668dd37d6cf74990ab23c828c2888dccdceee395594", + "sha256:fce7fc47dfc976152e82d53ff92fa0407700c21acd20886a13777a0d20e655dc" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==19.3.0" + "version": "==20.2.0" }, "cfgv": { "hashes": [ @@ -713,43 +725,43 @@ }, "coverage": { "hashes": [ - "sha256:098a703d913be6fbd146a8c50cc76513d726b022d170e5e98dc56d958fd592fb", - "sha256:16042dc7f8e632e0dcd5206a5095ebd18cb1d005f4c89694f7f8aafd96dd43a3", - "sha256:1adb6be0dcef0cf9434619d3b892772fdb48e793300f9d762e480e043bd8e716", - "sha256:27ca5a2bc04d68f0776f2cdcb8bbd508bbe430a7bf9c02315cd05fb1d86d0034", - "sha256:28f42dc5172ebdc32622a2c3f7ead1b836cdbf253569ae5673f499e35db0bac3", - "sha256:2fcc8b58953d74d199a1a4d633df8146f0ac36c4e720b4a1997e9b6327af43a8", - "sha256:304fbe451698373dc6653772c72c5d5e883a4aadaf20343592a7abb2e643dae0", - "sha256:30bc103587e0d3df9e52cd9da1dd915265a22fad0b72afe54daf840c984b564f", - "sha256:40f70f81be4d34f8d491e55936904db5c527b0711b2a46513641a5729783c2e4", - "sha256:4186fc95c9febeab5681bc3248553d5ec8c2999b8424d4fc3a39c9cba5796962", - "sha256:46794c815e56f1431c66d81943fa90721bb858375fb36e5903697d5eef88627d", - "sha256:4869ab1c1ed33953bb2433ce7b894a28d724b7aa76c19b11e2878034a4e4680b", - "sha256:4f6428b55d2916a69f8d6453e48a505c07b2245653b0aa9f0dee38785939f5e4", - "sha256:52f185ffd3291196dc1aae506b42e178a592b0b60a8610b108e6ad892cfc1bb3", - "sha256:538f2fd5eb64366f37c97fdb3077d665fa946d2b6d95447622292f38407f9258", - "sha256:64c4f340338c68c463f1b56e3f2f0423f7b17ba6c3febae80b81f0e093077f59", - "sha256:675192fca634f0df69af3493a48224f211f8db4e84452b08d5fcebb9167adb01", - "sha256:700997b77cfab016533b3e7dbc03b71d33ee4df1d79f2463a318ca0263fc29dd", - "sha256:8505e614c983834239f865da2dd336dcf9d72776b951d5dfa5ac36b987726e1b", - "sha256:962c44070c281d86398aeb8f64e1bf37816a4dfc6f4c0f114756b14fc575621d", - "sha256:9e536783a5acee79a9b308be97d3952b662748c4037b6a24cbb339dc7ed8eb89", - "sha256:9ea749fd447ce7fb1ac71f7616371f04054d969d412d37611716721931e36efd", - "sha256:a34cb28e0747ea15e82d13e14de606747e9e484fb28d63c999483f5d5188e89b", - "sha256:a3ee9c793ffefe2944d3a2bd928a0e436cd0ac2d9e3723152d6fd5398838ce7d", - "sha256:aab75d99f3f2874733946a7648ce87a50019eb90baef931698f96b76b6769a46", - "sha256:b1ed2bdb27b4c9fc87058a1cb751c4df8752002143ed393899edb82b131e0546", - "sha256:b360d8fd88d2bad01cb953d81fd2edd4be539df7bfec41e8753fe9f4456a5082", - "sha256:b8f58c7db64d8f27078cbf2a4391af6aa4e4767cc08b37555c4ae064b8558d9b", - "sha256:c1bbb628ed5192124889b51204de27c575b3ffc05a5a91307e7640eff1d48da4", - "sha256:c2ff24df02a125b7b346c4c9078c8936da06964cc2d276292c357d64378158f8", - "sha256:c890728a93fffd0407d7d37c1e6083ff3f9f211c83b4316fae3778417eab9811", - "sha256:c96472b8ca5dc135fb0aa62f79b033f02aa434fb03a8b190600a5ae4102df1fd", - "sha256:ce7866f29d3025b5b34c2e944e66ebef0d92e4a4f2463f7266daa03a1332a651", - "sha256:e26c993bd4b220429d4ec8c1468eca445a4064a61c74ca08da7429af9bc53bb0" - ], - "index": "pypi", - "version": "==5.2.1" + "sha256:0203acd33d2298e19b57451ebb0bed0ab0c602e5cf5a818591b4918b1f97d516", + "sha256:0f313707cdecd5cd3e217fc68c78a960b616604b559e9ea60cc16795c4304259", + "sha256:1c6703094c81fa55b816f5ae542c6ffc625fec769f22b053adb42ad712d086c9", + "sha256:1d44bb3a652fed01f1f2c10d5477956116e9b391320c94d36c6bf13b088a1097", + "sha256:280baa8ec489c4f542f8940f9c4c2181f0306a8ee1a54eceba071a449fb870a0", + "sha256:29a6272fec10623fcbe158fdf9abc7a5fa032048ac1d8631f14b50fbfc10d17f", + "sha256:2b31f46bf7b31e6aa690d4c7a3d51bb262438c6dcb0d528adde446531d0d3bb7", + "sha256:2d43af2be93ffbad25dd959899b5b809618a496926146ce98ee0b23683f8c51c", + "sha256:381ead10b9b9af5f64646cd27107fb27b614ee7040bb1226f9c07ba96625cbb5", + "sha256:47a11bdbd8ada9b7ee628596f9d97fbd3851bd9999d398e9436bd67376dbece7", + "sha256:4d6a42744139a7fa5b46a264874a781e8694bb32f1d76d8137b68138686f1729", + "sha256:50691e744714856f03a86df3e2bff847c2acede4c191f9a1da38f088df342978", + "sha256:530cc8aaf11cc2ac7430f3614b04645662ef20c348dce4167c22d99bec3480e9", + "sha256:582ddfbe712025448206a5bc45855d16c2e491c2dd102ee9a2841418ac1c629f", + "sha256:63808c30b41f3bbf65e29f7280bf793c79f54fb807057de7e5238ffc7cc4d7b9", + "sha256:71b69bd716698fa62cd97137d6f2fdf49f534decb23a2c6fc80813e8b7be6822", + "sha256:7858847f2d84bf6e64c7f66498e851c54de8ea06a6f96a32a1d192d846734418", + "sha256:78e93cc3571fd928a39c0b26767c986188a4118edc67bc0695bc7a284da22e82", + "sha256:7f43286f13d91a34fadf61ae252a51a130223c52bfefb50310d5b2deb062cf0f", + "sha256:86e9f8cd4b0cdd57b4ae71a9c186717daa4c5a99f3238a8723f416256e0b064d", + "sha256:8f264ba2701b8c9f815b272ad568d555ef98dfe1576802ab3149c3629a9f2221", + "sha256:9342dd70a1e151684727c9c91ea003b2fb33523bf19385d4554f7897ca0141d4", + "sha256:9361de40701666b034c59ad9e317bae95c973b9ff92513dd0eced11c6adf2e21", + "sha256:9669179786254a2e7e57f0ecf224e978471491d660aaca833f845b72a2df3709", + "sha256:aac1ba0a253e17889550ddb1b60a2063f7474155465577caa2a3b131224cfd54", + "sha256:aef72eae10b5e3116bac6957de1df4d75909fc76d1499a53fb6387434b6bcd8d", + "sha256:bd3166bb3b111e76a4f8e2980fa1addf2920a4ca9b2b8ca36a3bc3dedc618270", + "sha256:c1b78fb9700fc961f53386ad2fd86d87091e06ede5d118b8a50dea285a071c24", + "sha256:c3888a051226e676e383de03bf49eb633cd39fc829516e5334e69b8d81aae751", + "sha256:c5f17ad25d2c1286436761b462e22b5020d83316f8e8fcb5deb2b3151f8f1d3a", + "sha256:c851b35fc078389bc16b915a0a7c1d5923e12e2c5aeec58c52f4aa8085ac8237", + "sha256:cb7df71de0af56000115eafd000b867d1261f786b5eebd88a0ca6360cccfaca7", + "sha256:cedb2f9e1f990918ea061f28a0f0077a07702e3819602d3507e2ff98c8d20636", + "sha256:e8caf961e1b1a945db76f1b5fa9c91498d15f545ac0ababbe575cfab185d3bd8" + ], + "index": "pypi", + "version": "==5.3" }, "distlib": { "hashes": [ @@ -775,11 +787,11 @@ }, "flake8-annotations": { "hashes": [ - "sha256:7816a5d8f65ffdf37b8e21e5b17e0fd1e492aa92638573276de066e889a22b26", - "sha256:8d18db74a750dd97f40b483cc3ef80d07d03f687525bad8fd83365dcd3bfd414" + "sha256:09fe1aa3f40cb8fef632a0ab3614050a7584bb884b6134e70cf1fc9eeee642fa", + "sha256:5bda552f074fd6e34276c7761756fa07d824ffac91ce9c0a8555eb2bc5b92d7a" ], "index": "pypi", - "version": "==2.3.0" + "version": "==2.4.0" }, "flake8-bugbear": { "hashes": [ @@ -837,11 +849,11 @@ }, "identify": { "hashes": [ - "sha256:110ed090fec6bce1aabe3c72d9258a9de82207adeaa5a05cd75c635880312f9a", - "sha256:ccd88716b890ecbe10920659450a635d2d25de499b9a638525a48b48261d989b" + "sha256:c770074ae1f19e08aadbda1c886bc6d0cb55ffdc503a8c0fe8699af2fc9664ae", + "sha256:d02d004568c5a01261839a05e91705e3e9f5c57a3551648f9b3fb2b9c62c0f62" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==1.4.25" + "version": "==1.5.3" }, "mccabe": { "hashes": [ @@ -852,9 +864,10 @@ }, "nodeenv": { "hashes": [ - "sha256:4b0b77afa3ba9b54f4b6396e60b0c83f59eaeb2d63dc3cc7a70f7f4af96c82bc" + "sha256:5304d424c529c997bc888453aeaa6362d242b6b4631e90f3d4bf1b290f1c84a9", + "sha256:ab45090ae383b716c4ef89e690c41ff8c2b257b85b309f01f3654df3d084bd7c" ], - "version": "==1.4.0" + "version": "==1.5.0" }, "pep8-naming": { "hashes": [ @@ -866,11 +879,11 @@ }, "pre-commit": { "hashes": [ - "sha256:1657663fdd63a321a4a739915d7d03baedd555b25054449090f97bb0cb30a915", - "sha256:e8b1315c585052e729ab7e99dcca5698266bedce9067d21dc909c23e3ceed626" + "sha256:810aef2a2ba4f31eed1941fc270e72696a1ad5590b9751839c90807d0fff6b9a", + "sha256:c54fd3e574565fe128ecc5e7d2f91279772ddb03f8729645fa812fe809084a70" ], "index": "pypi", - "version": "==2.6.0" + "version": "==2.7.1" }, "pycodestyle": { "hashes": [ @@ -882,11 +895,11 @@ }, "pydocstyle": { "hashes": [ - "sha256:da7831660b7355307b32778c4a0dbfb137d89254ef31a2b2978f50fc0b4d7586", - "sha256:f4f5d210610c2d153fae39093d44224c17429e2ad7da12a8b419aba5c2f614b5" + "sha256:19b86fa8617ed916776a11cd8bc0197e5b9856d5433b777f51a3defe13075325", + "sha256:aca749e190a01726a4fb472dd4ef23b5c9da7b9205c0a7857c06533de13fd678" ], "markers": "python_version >= '3.5'", - "version": "==5.0.2" + "version": "==5.1.1" }, "pyflakes": { "hashes": [ @@ -937,19 +950,19 @@ }, "unittest-xml-reporting": { "hashes": [ - "sha256:74eaf7739a7957a74f52b8187c5616f61157372189bef0a32ba5c30bbc00e58a", - "sha256:e09b8ae70cce9904cdd331f53bf929150962869a5324ab7ff3dd6c8b87e01f7d" + "sha256:7bf515ea8cb244255a25100cd29db611a73f8d3d0aaf672ed3266307e14cc1ca", + "sha256:984cebba69e889401bfe3adb9088ca376b3a1f923f0590d005126c1bffd1a695" ], "index": "pypi", - "version": "==3.0.2" + "version": "==3.0.4" }, "virtualenv": { "hashes": [ - "sha256:7b54fd606a1b85f83de49ad8d80dbec08e983a2d2f96685045b262ebc7481ee5", - "sha256:8cd7b2a4850b003a11be2fc213e206419efab41115cc14bca20e69654f2ac08e" + "sha256:43add625c53c596d38f971a465553f6318decc39d98512bc100fa1b1e839c8dc", + "sha256:e0305af10299a7fb0d69393d8f04cb2965dda9351140d11ac8db4e5e3970451b" ], "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==20.0.30" + "version": "==20.0.31" } } } diff --git a/bot/__main__.py b/bot/__main__.py index fe2cf90e6..ee627be0a 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,7 +1,9 @@ +import asyncio import logging import discord import sentry_sdk +from async_rediscache import RedisSession from discord.ext.commands import when_mentioned_or from sentry_sdk.integrations.aiohttp import AioHttpIntegration from sentry_sdk.integrations.logging import LoggingIntegration @@ -24,8 +26,26 @@ sentry_sdk.init( ] ) +# Create the redis session instance. +redis_session = RedisSession( + address=(constants.Redis.host, constants.Redis.port), + password=constants.Redis.password, + minsize=1, + maxsize=20, + use_fakeredis=constants.Redis.use_fakeredis, +) + +# Connect redis session to ensure it's connected before we try to access Redis +# from somewhere within the bot. We create the event loop in the same way +# discord.py normally does and pass it to the bot's __init__. +loop = asyncio.get_event_loop() +loop.run_until_complete(redis_session.connect()) + + allowed_roles = [discord.Object(id_) for id_ in constants.MODERATION_ROLES] bot = Bot( + redis_session=redis_session, + loop=loop, command_prefix=when_mentioned_or(constants.Bot.prefix), activity=discord.Game(name="Commands: !help"), case_insensitive=True, diff --git a/bot/bot.py b/bot/bot.py index d25074fd9..b2e5237fe 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -6,9 +6,8 @@ from collections import defaultdict from typing import Dict, Optional import aiohttp -import aioredis import discord -import fakeredis.aioredis +from async_rediscache import RedisSession from discord.ext import commands from sentry_sdk import push_scope @@ -21,7 +20,7 @@ log = logging.getLogger('bot') class Bot(commands.Bot): """A subclass of `discord.ext.commands.Bot` with an aiohttp session and an API client.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, redis_session: RedisSession, **kwargs): if "connector" in kwargs: warnings.warn( "If login() is called (or the bot is started), the connector will be overwritten " @@ -31,9 +30,7 @@ class Bot(commands.Bot): super().__init__(*args, **kwargs) self.http_session: Optional[aiohttp.ClientSession] = None - self.redis_session: Optional[aioredis.Redis] = None - self.redis_ready = asyncio.Event() - self.redis_closed = False + self.redis_session = redis_session self.api_client = api.APIClient(loop=self.loop) self.filter_list_cache = defaultdict(dict) @@ -58,30 +55,6 @@ class Bot(commands.Bot): for item in full_cache: self.insert_item_into_filter_list_cache(item) - async def _create_redis_session(self) -> None: - """ - Create the Redis connection pool, and then open the redis event gate. - - If constants.Redis.use_fakeredis is True, we'll set up a fake redis pool instead - of attempting to communicate with a real Redis server. This is useful because it - means contributors don't necessarily need to get Redis running locally just - to run the bot. - - The fakeredis cache won't have persistence across restarts, but that - usually won't matter for local bot testing. - """ - if constants.Redis.use_fakeredis: - log.info("Using fakeredis instead of communicating with a real Redis server.") - self.redis_session = await fakeredis.aioredis.create_redis_pool() - else: - self.redis_session = await aioredis.create_redis_pool( - address=(constants.Redis.host, constants.Redis.port), - password=constants.Redis.password, - ) - - self.redis_closed = False - self.redis_ready.set() - def _recreate(self) -> None: """Re-create the connector, aiohttp session, the APIClient and the Redis session.""" # Use asyncio for DNS resolution instead of threads so threads aren't spammed. @@ -94,13 +67,10 @@ class Bot(commands.Bot): "The previous connector was not closed; it will remain open and be overwritten" ) - if self.redis_session and not self.redis_session.closed: - log.warning( - "The previous redis pool was not closed; it will remain open and be overwritten" - ) - - # Create the redis session - self.loop.create_task(self._create_redis_session()) + if self.redis_session.closed: + # If the RedisSession was somehow closed, we try to reconnect it + # here. Normally, this shouldn't happen. + self.loop.create_task(self.redis_session.connect()) # Use AF_INET as its socket family to prevent HTTPS related problems both locally # and in production. @@ -180,10 +150,7 @@ class Bot(commands.Bot): self.stats._transport.close() if self.redis_session: - self.redis_closed = True - self.redis_session.close() - self.redis_ready.clear() - await self.redis_session.wait_closed() + await self.redis_session.close() def insert_item_into_filter_list_cache(self, item: Dict[str, str]) -> None: """Add an item to the bots filter_list_cache.""" diff --git a/bot/cogs/dm_relay.py b/bot/cogs/dm_relay.py index 0d8f340b4..368d6a5d9 100644 --- a/bot/cogs/dm_relay.py +++ b/bot/cogs/dm_relay.py @@ -2,6 +2,7 @@ import logging from typing import Optional import discord +from async_rediscache import RedisCache from discord import Color from discord.ext import commands from discord.ext.commands import Cog @@ -9,7 +10,6 @@ from discord.ext.commands import Cog from bot import constants from bot.bot import Bot from bot.converters import UserMentionOrID -from bot.utils import RedisCache from bot.utils.checks import in_whitelist_check, with_role_check from bot.utils.messages import send_attachments from bot.utils.webhooks import send_webhook diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 99b659bff..0950ace60 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -6,6 +6,7 @@ from typing import List, Mapping, Optional, Tuple, Union import dateutil import discord.errors +from async_rediscache import RedisCache from dateutil.relativedelta import relativedelta from discord import Colour, HTTPException, Member, Message, NotFound, TextChannel from discord.ext.commands import Cog @@ -16,9 +17,8 @@ from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.constants import ( Channels, Colours, - Filter, Icons, URLs + Filter, Icons, URLs, ) -from bot.utils.redis_cache import RedisCache from bot.utils.regex import INVITE_RE from bot.utils.scheduling import Scheduler diff --git a/bot/cogs/help_channels.py b/bot/cogs/help_channels.py index 0f9cac89e..d4feb636c 100644 --- a/bot/cogs/help_channels.py +++ b/bot/cogs/help_channels.py @@ -9,11 +9,11 @@ from pathlib import Path import discord import discord.abc +from async_rediscache import RedisCache from discord.ext import commands from bot import constants from bot.bot import Bot -from bot.utils import RedisCache from bot.utils.checks import with_role_check from bot.utils.scheduling import Scheduler diff --git a/bot/cogs/verification.py b/bot/cogs/verification.py index 9ae92a228..859cc26e1 100644 --- a/bot/cogs/verification.py +++ b/bot/cogs/verification.py @@ -5,6 +5,7 @@ from contextlib import suppress from datetime import datetime, timedelta import discord +from async_rediscache import RedisCache from discord.ext import tasks from discord.ext.commands import Cog, Context, command, group from discord.utils import snowflake_time @@ -14,7 +15,6 @@ from bot.bot import Bot from bot.cogs.moderation import ModLog from bot.decorators import in_whitelist, with_role, without_role from bot.utils.checks import InWhitelistCheckFailure, without_role_check -from bot.utils.redis_cache import RedisCache log = logging.getLogger(__name__) diff --git a/tests/helpers.py b/tests/helpers.py index facc4e1af..e47fdf28f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -308,7 +308,11 @@ class MockBot(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.ext.commands.Bot` instances. For more information, see the `MockGuild` docstring. """ - spec_set = Bot(command_prefix=unittest.mock.MagicMock(), loop=_get_mock_loop()) + spec_set = Bot( + command_prefix=unittest.mock.MagicMock(), + loop=_get_mock_loop(), + redis_session=unittest.mock.MagicMock(), + ) additional_spec_asyncs = ("wait_for", "redis_ready") def __init__(self, **kwargs) -> None: -- cgit v1.2.3 From de4a8d960f9f845467efb470e17a0e9685df1bdc Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Sat, 19 Sep 2020 21:48:41 +0200 Subject: Remove vestigial RedisCache class definition As we're now using the `async-rediscache` package, it's no longer necessary to keep the `RedisCache` defined in `bot.utils.redis_cache` around. I've removed the file containing it and the tests written for it. Signed-off-by: Sebastiaan Zeeff --- bot/utils/__init__.py | 3 +- bot/utils/redis_cache.py | 414 ------------------------------------ tests/bot/utils/test_redis_cache.py | 265 ----------------------- 3 files changed, 1 insertion(+), 681 deletions(-) delete mode 100644 bot/utils/redis_cache.py delete mode 100644 tests/bot/utils/test_redis_cache.py (limited to 'tests') diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index 3e93fcb06..60170a88f 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,5 +1,4 @@ from bot.utils.helpers import CogABCMeta, find_nth_occurrence, pad_base64 -from bot.utils.redis_cache import RedisCache from bot.utils.services import send_to_paste_service -__all__ = ['RedisCache', 'CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service'] +__all__ = ['CogABCMeta', 'find_nth_occurrence', 'pad_base64', 'send_to_paste_service'] diff --git a/bot/utils/redis_cache.py b/bot/utils/redis_cache.py deleted file mode 100644 index 52b689b49..000000000 --- a/bot/utils/redis_cache.py +++ /dev/null @@ -1,414 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -from functools import partialmethod -from typing import Any, Dict, ItemsView, Optional, Tuple, Union - -from bot.bot import Bot - -log = logging.getLogger(__name__) - -# Type aliases -RedisKeyType = Union[str, int] -RedisValueType = Union[str, int, float, bool] -RedisKeyOrValue = Union[RedisKeyType, RedisValueType] - -# Prefix tuples -_PrefixTuple = Tuple[Tuple[str, Any], ...] -_VALUE_PREFIXES = ( - ("f|", float), - ("i|", int), - ("s|", str), - ("b|", bool), -) -_KEY_PREFIXES = ( - ("i|", int), - ("s|", str), -) - - -class NoBotInstanceError(RuntimeError): - """Raised when RedisCache is created without an available bot instance on the owner class.""" - - -class NoNamespaceError(RuntimeError): - """Raised when RedisCache has no namespace, for example if it is not assigned to a class attribute.""" - - -class NoParentInstanceError(RuntimeError): - """Raised when the parent instance is available, for example if called by accessing the parent class directly.""" - - -class RedisCache: - """ - A simplified interface for a Redis connection. - - We implement several convenient methods that are fairly similar to have a dict - behaves, and should be familiar to Python users. The biggest difference is that - all the public methods in this class are coroutines, and must be awaited. - - Because of limitations in Redis, this cache will only accept strings and integers for keys, - and strings, integers, floats and booleans for values. - - Please note that this class MUST be created as a class attribute, and that that class - must also contain an attribute with an instance of our Bot. See `__get__` and `__set_name__` - for more information about how this works. - - Simple example for how to use this: - - class SomeCog(Cog): - # To initialize a valid RedisCache, just add it as a class attribute here. - # Do not add it to the __init__ method or anywhere else, it MUST be a class - # attribute. Do not pass any parameters. - cache = RedisCache() - - async def my_method(self): - - # Now we're ready to use the RedisCache. - # One thing to note here is that this will not work unless - # we access self.cache through an _instance_ of this class. - # - # For example, attempting to use SomeCog.cache will _not_ work, - # you _must_ instantiate the class first and use that instance. - # - # Now we can store some stuff in the cache just by doing this. - # This data will persist through restarts! - await self.cache.set("key", "value") - - # To get the data, simply do this. - value = await self.cache.get("key") - - # Other methods work more or less like a dictionary. - # Checking if something is in the cache - await self.cache.contains("key") - - # iterating the cache - async for key, value in self.cache.items(): - print(value) - - # We can even iterate in a comprehension! - consumed = [value async for key, value in self.cache.items()] - """ - - _namespaces = [] - - def __init__(self) -> None: - """Initialize the RedisCache.""" - self._namespace = None - self.bot = None - self._increment_lock = None - - def _set_namespace(self, namespace: str) -> None: - """Try to set the namespace, but do not permit collisions.""" - log.trace(f"RedisCache setting namespace to {namespace}") - self._namespaces.append(namespace) - self._namespace = namespace - - @staticmethod - def _to_typestring(key_or_value: RedisKeyOrValue, prefixes: _PrefixTuple) -> str: - """Turn a valid Redis type into a typestring.""" - for prefix, _type in prefixes: - # Convert bools into integers before storing them. - if type(key_or_value) is bool: - bool_int = int(key_or_value) - return f"{prefix}{bool_int}" - - # isinstance is a bad idea here, because isintance(False, int) == True. - if type(key_or_value) is _type: - return f"{prefix}{key_or_value}" - - raise TypeError(f"RedisCache._to_typestring only supports the following: {prefixes}.") - - @staticmethod - def _from_typestring(key_or_value: Union[bytes, str], prefixes: _PrefixTuple) -> RedisKeyOrValue: - """Deserialize a typestring into a valid Redis type.""" - # Stuff that comes out of Redis will be bytestrings, so let's decode those. - if isinstance(key_or_value, bytes): - key_or_value = key_or_value.decode('utf-8') - - # Now we convert our unicode string back into the type it originally was. - for prefix, _type in prefixes: - if key_or_value.startswith(prefix): - - # For booleans, we need special handling because bool("False") is True. - if prefix == "b|": - value = key_or_value[len(prefix):] - return bool(int(value)) - - # Otherwise we can just convert normally. - return _type(key_or_value[len(prefix):]) - raise TypeError(f"RedisCache._from_typestring only supports the following: {prefixes}.") - - # Add some nice partials to call our generic typestring converters. - # These are basically methods that will fill in some of the parameters for you, so that - # any call to _key_to_typestring will be like calling _to_typestring with the two parameters - # at `prefixes` and `types_string` pre-filled. - # - # See https://docs.python.org/3/library/functools.html#functools.partialmethod - _key_to_typestring = partialmethod(_to_typestring, prefixes=_KEY_PREFIXES) - _value_to_typestring = partialmethod(_to_typestring, prefixes=_VALUE_PREFIXES) - _key_from_typestring = partialmethod(_from_typestring, prefixes=_KEY_PREFIXES) - _value_from_typestring = partialmethod(_from_typestring, prefixes=_VALUE_PREFIXES) - - def _dict_from_typestring(self, dictionary: Dict) -> Dict: - """Turns all contents of a dict into valid Redis types.""" - return {self._key_from_typestring(key): self._value_from_typestring(value) for key, value in dictionary.items()} - - def _dict_to_typestring(self, dictionary: Dict) -> Dict: - """Turns all contents of a dict into typestrings.""" - return {self._key_to_typestring(key): self._value_to_typestring(value) for key, value in dictionary.items()} - - async def _validate_cache(self) -> None: - """Validate that the RedisCache is ready to be used.""" - if self._namespace is None: - error_message = ( - "Critical error: RedisCache has no namespace. " - "This object must be initialized as a class attribute." - ) - log.error(error_message) - raise NoNamespaceError(error_message) - - if self.bot is None: - error_message = ( - "Critical error: RedisCache has no `Bot` instance. " - "This happens when the class RedisCache was created in doesn't " - "have a Bot instance. Please make sure that you're instantiating " - "the RedisCache inside a class that has a Bot instance attribute." - ) - log.error(error_message) - raise NoBotInstanceError(error_message) - - if not self.bot.redis_closed: - await self.bot.redis_ready.wait() - - def __set_name__(self, owner: Any, attribute_name: str) -> None: - """ - Set the namespace to Class.attribute_name. - - Called automatically when this class is constructed inside a class as an attribute. - - This class MUST be created as a class attribute in a class, otherwise it will raise - exceptions whenever a method is used. This is because it uses this method to create - a namespace like `MyCog.my_class_attribute` which is used as a hash name when we store - stuff in Redis, to prevent collisions. - """ - self._set_namespace(f"{owner.__name__}.{attribute_name}") - - def __get__(self, instance: RedisCache, owner: Any) -> RedisCache: - """ - This is called if the RedisCache is a class attribute, and is accessed. - - The class this object is instantiated in must contain an attribute with an - instance of Bot. This is because Bot contains our redis_session, which is - the mechanism by which we will communicate with the Redis server. - - Any attempt to use RedisCache in a class that does not have a Bot instance - will fail. It is mostly intended to be used inside of a Cog, although theoretically - it should work in any class that has a Bot instance. - """ - if self.bot: - return self - - if self._namespace is None: - error_message = "RedisCache must be a class attribute." - log.error(error_message) - raise NoNamespaceError(error_message) - - if instance is None: - error_message = ( - "You must access the RedisCache instance through the cog instance " - "before accessing it using the cog's class object." - ) - log.error(error_message) - raise NoParentInstanceError(error_message) - - for attribute in vars(instance).values(): - if isinstance(attribute, Bot): - self.bot = attribute - return self - else: - error_message = ( - "Critical error: RedisCache has no `Bot` instance. " - "This happens when the class RedisCache was created in doesn't " - "have a Bot instance. Please make sure that you're instantiating " - "the RedisCache inside a class that has a Bot instance attribute." - ) - log.error(error_message) - raise NoBotInstanceError(error_message) - - def __repr__(self) -> str: - """Return a beautiful representation of this object instance.""" - return f"RedisCache(namespace={self._namespace!r})" - - async def set(self, key: RedisKeyType, value: RedisValueType) -> None: - """Store an item in the Redis cache.""" - await self._validate_cache() - - # Convert to a typestring and then set it - key = self._key_to_typestring(key) - value = self._value_to_typestring(value) - - log.trace(f"Setting {key} to {value}.") - await self.bot.redis_session.hset(self._namespace, key, value) - - async def get(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> Optional[RedisValueType]: - """Get an item from the Redis cache.""" - await self._validate_cache() - key = self._key_to_typestring(key) - - log.trace(f"Attempting to retrieve {key}.") - value = await self.bot.redis_session.hget(self._namespace, key) - - if value is None: - log.trace(f"Value not found, returning default value {default}") - return default - else: - value = self._value_from_typestring(value) - log.trace(f"Value found, returning value {value}") - return value - - async def delete(self, key: RedisKeyType) -> None: - """ - Delete an item from the Redis cache. - - If we try to delete a key that does not exist, it will simply be ignored. - - See https://redis.io/commands/hdel for more info on how this works. - """ - await self._validate_cache() - key = self._key_to_typestring(key) - - log.trace(f"Attempting to delete {key}.") - return await self.bot.redis_session.hdel(self._namespace, key) - - async def contains(self, key: RedisKeyType) -> bool: - """ - Check if a key exists in the Redis cache. - - Return True if the key exists, otherwise False. - """ - await self._validate_cache() - key = self._key_to_typestring(key) - exists = await self.bot.redis_session.hexists(self._namespace, key) - - log.trace(f"Testing if {key} exists in the RedisCache - Result is {exists}") - return exists - - async def items(self) -> ItemsView: - """ - Fetch all the key/value pairs in the cache. - - Returns a normal ItemsView, like you would get from dict.items(). - - Keep in mind that these items are just a _copy_ of the data in the - RedisCache - any changes you make to them will not be reflected - into the RedisCache itself. If you want to change these, you need - to make a .set call. - - Example: - items = await my_cache.items() - for key, value in items: - # Iterate like a normal dictionary - """ - await self._validate_cache() - items = self._dict_from_typestring( - await self.bot.redis_session.hgetall(self._namespace) - ).items() - - log.trace(f"Retrieving all key/value pairs from cache, total of {len(items)} items.") - return items - - async def length(self) -> int: - """Return the number of items in the Redis cache.""" - await self._validate_cache() - number_of_items = await self.bot.redis_session.hlen(self._namespace) - log.trace(f"Returning length. Result is {number_of_items}.") - return number_of_items - - async def to_dict(self) -> Dict: - """Convert to dict and return.""" - return {key: value for key, value in await self.items()} - - async def clear(self) -> None: - """Deletes the entire hash from the Redis cache.""" - await self._validate_cache() - log.trace("Clearing the cache of all key/value pairs.") - await self.bot.redis_session.delete(self._namespace) - - async def pop(self, key: RedisKeyType, default: Optional[RedisValueType] = None) -> RedisValueType: - """Get the item, remove it from the cache, and provide a default if not found.""" - log.trace(f"Attempting to pop {key}.") - value = await self.get(key, default) - - log.trace( - f"Attempting to delete item with key '{key}' from the cache. " - "If this key doesn't exist, nothing will happen." - ) - await self.delete(key) - - return value - - async def update(self, items: Dict[RedisKeyType, RedisValueType]) -> None: - """ - Update the Redis cache with multiple values. - - This works exactly like dict.update from a normal dictionary. You pass - a dictionary with one or more key/value pairs into this method. If the keys - do not exist in the RedisCache, they are created. If they do exist, the values - are updated with the new ones from `items`. - - Please note that keys and the values in the `items` dictionary - must consist of valid RedisKeyTypes and RedisValueTypes. - """ - await self._validate_cache() - log.trace(f"Updating the cache with the following items:\n{items}") - await self.bot.redis_session.hmset_dict(self._namespace, self._dict_to_typestring(items)) - - async def increment(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None: - """ - Increment the value by `amount`. - - This works for both floats and ints, but will raise a TypeError - if you try to do it for any other type of value. - - This also supports negative amounts, although it would provide better - readability to use .decrement() for that. - """ - log.trace(f"Attempting to increment/decrement the value with the key {key} by {amount}.") - - # We initialize the lock here, because we need to ensure we get it - # running on the same loop as the calling coroutine. - # - # If we initialized the lock in the __init__, the loop that the coroutine this method - # would be called from might not exist yet, and so the lock would be on a different - # loop, which would raise RuntimeErrors. - if self._increment_lock is None: - self._increment_lock = asyncio.Lock() - - # Since this has several API calls, we need a lock to prevent race conditions - async with self._increment_lock: - value = await self.get(key) - - # Can't increment a non-existing value - if value is None: - error_message = "The provided key does not exist!" - log.error(error_message) - raise KeyError(error_message) - - # If it does exist, and it's an int or a float, increment and set it. - if isinstance(value, int) or isinstance(value, float): - value += amount - await self.set(key, value) - else: - error_message = "You may only increment or decrement values that are integers or floats." - log.error(error_message) - raise TypeError(error_message) - - async def decrement(self, key: RedisKeyType, amount: Optional[int, float] = 1) -> None: - """ - Decrement the value by `amount`. - - Basically just does the opposite of .increment. - """ - await self.increment(key, -amount) diff --git a/tests/bot/utils/test_redis_cache.py b/tests/bot/utils/test_redis_cache.py deleted file mode 100644 index a2f0fe55d..000000000 --- a/tests/bot/utils/test_redis_cache.py +++ /dev/null @@ -1,265 +0,0 @@ -import asyncio -import unittest - -import fakeredis.aioredis - -from bot.utils import RedisCache -from bot.utils.redis_cache import NoBotInstanceError, NoNamespaceError, NoParentInstanceError -from tests import helpers - - -class RedisCacheTests(unittest.IsolatedAsyncioTestCase): - """Tests the RedisCache class from utils.redis_dict.py.""" - - async def asyncSetUp(self): # noqa: N802 - """Sets up the objects that only have to be initialized once.""" - self.bot = helpers.MockBot() - self.bot.redis_session = await fakeredis.aioredis.create_redis_pool() - - # Okay, so this is necessary so that we can create a clean new - # class for every test method, and we want that because it will - # ensure we get a fresh loop, which is necessary for test_increment_lock - # to be able to pass. - class DummyCog: - """A dummy cog, for dummies.""" - - redis = RedisCache() - - def __init__(self, bot: helpers.MockBot): - self.bot = bot - - self.cog = DummyCog(self.bot) - - await self.cog.redis.clear() - - def test_class_attribute_namespace(self): - """Test that RedisDict creates a namespace automatically for class attributes.""" - self.assertEqual(self.cog.redis._namespace, "DummyCog.redis") - - async def test_class_attribute_required(self): - """Test that errors are raised when not assigned as a class attribute.""" - bad_cache = RedisCache() - self.assertIs(bad_cache._namespace, None) - - with self.assertRaises(RuntimeError): - await bad_cache.set("test", "me_up_deadman") - - async def test_set_get_item(self): - """Test that users can set and get items from the RedisDict.""" - test_cases = ( - ('favorite_fruit', 'melon'), - ('favorite_number', 86), - ('favorite_fraction', 86.54), - ('favorite_boolean', False), - ('other_boolean', True), - ) - - # Test that we can get and set different types. - for test in test_cases: - await self.cog.redis.set(*test) - self.assertEqual(await self.cog.redis.get(test[0]), test[1]) - - # Test that .get allows a default value - self.assertEqual(await self.cog.redis.get('favorite_nothing', "bearclaw"), "bearclaw") - - async def test_set_item_type(self): - """Test that .set rejects keys and values that are not permitted.""" - fruits = ["lemon", "melon", "apple"] - - with self.assertRaises(TypeError): - await self.cog.redis.set(fruits, "nice") - - with self.assertRaises(TypeError): - await self.cog.redis.set(4.23, "nice") - - async def test_delete_item(self): - """Test that .delete allows us to delete stuff from the RedisCache.""" - # Add an item and verify that it gets added - await self.cog.redis.set("internet", "firetruck") - self.assertEqual(await self.cog.redis.get("internet"), "firetruck") - - # Delete that item and verify that it gets deleted - await self.cog.redis.delete("internet") - self.assertIs(await self.cog.redis.get("internet"), None) - - async def test_contains(self): - """Test that we can check membership with .contains.""" - await self.cog.redis.set('favorite_country', "Burkina Faso") - - self.assertIs(await self.cog.redis.contains('favorite_country'), True) - self.assertIs(await self.cog.redis.contains('favorite_dentist'), False) - - async def test_items(self): - """Test that the RedisDict can be iterated.""" - # Set up our test cases in the Redis cache - test_cases = [ - ('favorite_turtle', 'Donatello'), - ('second_favorite_turtle', 'Leonardo'), - ('third_favorite_turtle', 'Raphael'), - ] - for key, value in test_cases: - await self.cog.redis.set(key, value) - - # Consume the AsyncIterator into a regular list, easier to compare that way. - redis_items = [item for item in await self.cog.redis.items()] - - # These sequences are probably in the same order now, but probably - # isn't good enough for tests. Let's not rely on .hgetall always - # returning things in sequence, and just sort both lists to be safe. - redis_items = sorted(redis_items) - test_cases = sorted(test_cases) - - # If these are equal now, everything works fine. - self.assertSequenceEqual(test_cases, redis_items) - - async def test_length(self): - """Test that we can get the correct .length from the RedisDict.""" - await self.cog.redis.set('one', 1) - await self.cog.redis.set('two', 2) - await self.cog.redis.set('three', 3) - self.assertEqual(await self.cog.redis.length(), 3) - - await self.cog.redis.set('four', 4) - self.assertEqual(await self.cog.redis.length(), 4) - - async def test_to_dict(self): - """Test that the .to_dict method returns a workable dictionary copy.""" - copy = await self.cog.redis.to_dict() - local_copy = {key: value for key, value in await self.cog.redis.items()} - self.assertIs(type(copy), dict) - self.assertDictEqual(copy, local_copy) - - async def test_clear(self): - """Test that the .clear method removes the entire hash.""" - await self.cog.redis.set('teddy', 'with me') - await self.cog.redis.set('in my dreams', 'you have a weird hat') - self.assertEqual(await self.cog.redis.length(), 2) - - await self.cog.redis.clear() - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_pop(self): - """Test that we can .pop an item from the RedisDict.""" - await self.cog.redis.set('john', 'was afraid') - - self.assertEqual(await self.cog.redis.pop('john'), 'was afraid') - self.assertEqual(await self.cog.redis.pop('pete', 'breakneck'), 'breakneck') - self.assertEqual(await self.cog.redis.length(), 0) - - async def test_update(self): - """Test that we can .update the RedisDict with multiple items.""" - await self.cog.redis.set("reckfried", "lona") - await self.cog.redis.set("bel air", "prince") - await self.cog.redis.update({ - "reckfried": "jona", - "mega": "hungry, though", - }) - - result = { - "reckfried": "jona", - "bel air": "prince", - "mega": "hungry, though", - } - self.assertDictEqual(await self.cog.redis.to_dict(), result) - - def test_typestring_conversion(self): - """Test the typestring-related helper functions.""" - conversion_tests = ( - (12, "i|12"), - (12.4, "f|12.4"), - ("cowabunga", "s|cowabunga"), - ) - - # Test conversion to typestring - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_to_typestring(_input), expected) - - # Test conversion from typestrings - for _input, expected in conversion_tests: - self.assertEqual(self.cog.redis._value_from_typestring(expected), _input) - - # Test that exceptions are raised on invalid input - with self.assertRaises(TypeError): - self.cog.redis._value_to_typestring(["internet"]) - self.cog.redis._value_from_typestring("o|firedog") - - async def test_increment_decrement(self): - """Test .increment and .decrement methods.""" - await self.cog.redis.set("entropic", 5) - await self.cog.redis.set("disentropic", 12.5) - - # Test default increment - await self.cog.redis.increment("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 6) - - # Test default decrement - await self.cog.redis.decrement("entropic") - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # Test float increment with float - await self.cog.redis.increment("disentropic", 2.0) - self.assertEqual(await self.cog.redis.get("disentropic"), 14.5) - - # Test float increment with int - await self.cog.redis.increment("disentropic", 2) - self.assertEqual(await self.cog.redis.get("disentropic"), 16.5) - - # Test negative increments, because why not. - await self.cog.redis.increment("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 0) - - # Negative decrements? Sure. - await self.cog.redis.decrement("entropic", -5) - self.assertEqual(await self.cog.redis.get("entropic"), 5) - - # What about if we use a negative float to decrement an int? - # This should convert the type into a float. - await self.cog.redis.decrement("entropic", -2.5) - self.assertEqual(await self.cog.redis.get("entropic"), 7.5) - - # Let's test that they raise the right errors - with self.assertRaises(KeyError): - await self.cog.redis.increment("doesn't_exist!") - - await self.cog.redis.set("stringthing", "stringthing") - with self.assertRaises(TypeError): - await self.cog.redis.increment("stringthing") - - async def test_increment_lock(self): - """Test that we can't produce a race condition in .increment.""" - await self.cog.redis.set("test_key", 0) - tasks = [] - - # Increment this a lot in different tasks - for _ in range(100): - task = asyncio.create_task( - self.cog.redis.increment("test_key", 1) - ) - tasks.append(task) - await asyncio.gather(*tasks) - - # Confirm that the value has been incremented the exact right number of times. - value = await self.cog.redis.get("test_key") - self.assertEqual(value, 100) - - async def test_exceptions_raised(self): - """Testing that the various RuntimeErrors are reachable.""" - class MyCog: - cache = RedisCache() - - def __init__(self): - self.other_cache = RedisCache() - - cog = MyCog() - - # Raises "No Bot instance" - with self.assertRaises(NoBotInstanceError): - await cog.cache.get("john") - - # Raises "RedisCache has no namespace" - with self.assertRaises(NoNamespaceError): - await cog.other_cache.get("was") - - # Raises "You must access the RedisCache instance through the cog instance" - with self.assertRaises(NoParentInstanceError): - await MyCog.cache.get("afraid") -- cgit v1.2.3 From 0c5c472362d2891853bb82b225aa86da74d597e2 Mon Sep 17 00:00:00 2001 From: Sebastiaan Zeeff Date: Sun, 20 Sep 2020 12:43:00 +0200 Subject: Remove unit tests for duck pond I've removed the unit tests for duckpond in concordance with the new policy for writing unit tests for the bot The tests were unnecessarily complicated, difficult to maintain, and slowed development. Signed-off-by: Sebastiaan Zeeff --- tests/bot/cogs/test_duck_pond.py | 548 --------------------------------------- 1 file changed, 548 deletions(-) delete mode 100644 tests/bot/cogs/test_duck_pond.py (limited to 'tests') diff --git a/tests/bot/cogs/test_duck_pond.py b/tests/bot/cogs/test_duck_pond.py deleted file mode 100644 index cfe10aebf..000000000 --- a/tests/bot/cogs/test_duck_pond.py +++ /dev/null @@ -1,548 +0,0 @@ -import asyncio -import logging -import typing -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -import discord - -from bot import constants -from bot.cogs import duck_pond -from tests import base -from tests import helpers - -MODULE_PATH = "bot.cogs.duck_pond" - - -class DuckPondTests(base.LoggingTestsMixin, unittest.IsolatedAsyncioTestCase): - """Tests for DuckPond functionality.""" - - @classmethod - def setUpClass(cls): - """Sets up the objects that only have to be initialized once.""" - cls.nonstaff_member = helpers.MockMember(name="Non-staffer") - - cls.staff_role = helpers.MockRole(name="Staff role", id=constants.STAFF_ROLES[0]) - cls.staff_member = helpers.MockMember(name="staffer", roles=[cls.staff_role]) - - cls.checkmark_emoji = "\N{White Heavy Check Mark}" - cls.thumbs_up_emoji = "\N{Thumbs Up Sign}" - cls.unicode_duck_emoji = "\N{Duck}" - cls.duck_pond_emoji = helpers.MockPartialEmoji(id=constants.DuckPond.custom_emojis[0]) - cls.non_duck_custom_emoji = helpers.MockPartialEmoji(id=123) - - def setUp(self): - """Sets up the objects that need to be refreshed before each test.""" - self.bot = helpers.MockBot(user=helpers.MockMember(id=46692)) - self.cog = duck_pond.DuckPond(bot=self.bot) - - def test_duck_pond_correctly_initializes(self): - """`__init__ should set `bot` and `webhook_id` attributes and schedule `fetch_webhook`.""" - bot = helpers.MockBot() - cog = MagicMock() - - duck_pond.DuckPond.__init__(cog, bot) - - self.assertEqual(cog.bot, bot) - self.assertEqual(cog.webhook_id, constants.Webhooks.duck_pond) - bot.loop.create_task.assert_called_once_with(cog.fetch_webhook()) - - def test_fetch_webhook_succeeds_without_connectivity_issues(self): - """The `fetch_webhook` method waits until `READY` event and sets the `webhook` attribute.""" - self.bot.fetch_webhook.return_value = "dummy webhook" - self.cog.webhook_id = 1 - - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - self.assertEqual(self.cog.webhook, "dummy webhook") - - def test_fetch_webhook_logs_when_unable_to_fetch_webhook(self): - """The `fetch_webhook` method should log an exception when it fails to fetch the webhook.""" - self.bot.fetch_webhook.side_effect = discord.HTTPException(response=MagicMock(), message="Not found.") - self.cog.webhook_id = 1 - - log = logging.getLogger('bot.cogs.duck_pond') - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - asyncio.run(self.cog.fetch_webhook()) - - self.bot.wait_until_guild_available.assert_called_once() - self.bot.fetch_webhook.assert_called_once_with(1) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def test_is_staff_returns_correct_values_based_on_instance_passed(self): - """The `is_staff` method should return correct values based on the instance passed.""" - test_cases = ( - (helpers.MockUser(name="User instance"), False), - (helpers.MockMember(name="Member instance without staff role"), False), - (helpers.MockMember(name="Member instance with staff role", roles=[self.staff_role]), True) - ) - - for user, expected_return in test_cases: - actual_return = self.cog.is_staff(user) - with self.subTest(user_type=user.name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - async def test_has_green_checkmark_correctly_detects_presence_of_green_checkmark_emoji(self): - """The `has_green_checkmark` method should only return `True` if one is present.""" - test_cases = ( - ( - "No reactions", helpers.MockMessage(), False - ), - ( - "No green check mark reactions", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.thumbs_up_emoji, users=[self.bot.user]) - ]), - False - ), - ( - "Green check mark reaction, but not from our bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member]) - ]), - False - ), - ( - "Green check mark reaction, with one from the bot", - helpers.MockMessage(reactions=[ - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.bot.user]), - helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.staff_member, self.bot.user]) - ]), - True - ) - ) - - for description, message, expected_return in test_cases: - actual_return = await self.cog.has_green_checkmark(message) - with self.subTest( - test_case=description, - expected_return=expected_return, - actual_return=actual_return - ): - self.assertEqual(expected_return, actual_return) - - def _get_reaction( - self, - emoji: typing.Union[str, helpers.MockEmoji], - staff: int = 0, - nonstaff: int = 0 - ) -> helpers.MockReaction: - staffers = [helpers.MockMember(roles=[self.staff_role]) for _ in range(staff)] - nonstaffers = [helpers.MockMember() for _ in range(nonstaff)] - return helpers.MockReaction(emoji=emoji, users=staffers + nonstaffers) - - async def test_count_ducks_correctly_counts_the_number_of_eligible_duck_emojis(self): - """The `count_ducks` method should return the number of unique staffers who gave a duck.""" - test_cases = ( - # Simple test cases - # A message without reactions should return 0 - ( - "No reactions", - helpers.MockMessage(), - 0 - ), - # A message with a non-duck reaction from a non-staffer should return 0 - ( - "Non-duck reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, nonstaff=1)]), - 0 - ), - # A message with a non-duck reaction from a staffer should return 0 - ( - "Non-duck reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.non_duck_custom_emoji, staff=1)]), - 0 - ), - # A message with a non-duck reaction from a non-staffer and staffer should return 0 - ( - "Non-duck reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.thumbs_up_emoji, staff=1, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a non-staffer should return 0 - ( - "Unicode Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, nonstaff=1)]), - 0 - ), - # A message with a unicode duck reaction from a staffer should return 1 - ( - "Unicode Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1)]), - 1 - ), - # A message with a unicode duck reaction from a non-staffer and staffer should return 1 - ( - "Unicode Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.unicode_duck_emoji, staff=1, nonstaff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer should return 0 - ( - "Duckpond Duck Reaction from non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, nonstaff=1)]), - 0 - ), - # A message with a duckpond duck reaction from a staffer should return 1 - ( - "Duckpond Duck Reaction from staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1)]), - 1 - ), - # A message with a duckpond duck reaction from a non-staffer and staffer should return 1 - ( - "Duckpond Duck Reaction from staffer + non-staffer", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=1, nonstaff=1)]), - 1 - ), - - # Complex test cases - # A message with duckpond duck reactions from 3 staffers and 2 non-staffers returns 3 - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2)]), - 3 - ), - # A staffer with multiple duck reactions only counts once - ( - "Two different duck reactions from the same staffer", - helpers.MockMessage( - reactions=[ - helpers.MockReaction(emoji=self.duck_pond_emoji, users=[self.staff_member]), - helpers.MockReaction(emoji=self.unicode_duck_emoji, users=[self.staff_member]), - ] - ), - 1 - ), - # A non-string emoji does not count (to test the `isinstance(reaction.emoji, str)` elif) - ( - "Reaction with non-Emoji/str emoij from 3 staffers + 2 non-staffers", - helpers.MockMessage(reactions=[self._get_reaction(emoji=100, staff=3, nonstaff=2)]), - 0 - ), - # We correctly sum when multiple reactions are provided. - ( - "Duckpond Duck Reaction from 3 staffers + 2 non-staffers", - helpers.MockMessage( - reactions=[ - self._get_reaction(emoji=self.duck_pond_emoji, staff=3, nonstaff=2), - self._get_reaction(emoji=self.unicode_duck_emoji, staff=4, nonstaff=9), - ] - ), - 3 + 4 - ), - ) - - for description, message, expected_count in test_cases: - actual_count = await self.cog.count_ducks(message) - with self.subTest(test_case=description, expected_count=expected_count, actual_count=actual_count): - self.assertEqual(expected_count, actual_count) - - async def test_relay_message_correctly_relays_content_and_attachments(self): - """The `relay_message` method should correctly relay message content and attachments.""" - send_webhook_path = f"{MODULE_PATH}.send_webhook" - send_attachments_path = f"{MODULE_PATH}.send_attachments" - author = MagicMock( - display_name="x", - avatar_url="https://" - ) - - self.cog.webhook = helpers.MockAsyncWebhook() - - test_values = ( - (helpers.MockMessage(author=author, clean_content="", attachments=[]), False, False), - (helpers.MockMessage(author=author, clean_content="message", attachments=[]), True, False), - (helpers.MockMessage(author=author, clean_content="", attachments=["attachment"]), False, True), - (helpers.MockMessage(author=author, clean_content="message", attachments=["attachment"]), True, True), - ) - - for message, expect_webhook_call, expect_attachment_call in test_values: - with patch(send_webhook_path, new_callable=AsyncMock) as send_webhook: - with patch(send_attachments_path, new_callable=AsyncMock) as send_attachments: - with self.subTest(clean_content=message.clean_content, attachments=message.attachments): - await self.cog.relay_message(message) - - self.assertEqual(expect_webhook_call, send_webhook.called) - self.assertEqual(expect_attachment_call, send_attachments.called) - - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_irretrievable_attachment_exceptions(self, send_attachments): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - side_effects = (discord.errors.Forbidden(MagicMock(), ""), discord.errors.NotFound(MagicMock(), "")) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - for side_effect in side_effects: # pragma: no cover - send_attachments.side_effect = side_effect - with patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) as send_webhook: - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertNotLogs(logger=log, level=logging.ERROR): - await self.cog.relay_message(message) - - self.assertEqual(send_webhook.call_count, 2) - - @patch(f"{MODULE_PATH}.send_webhook", new_callable=AsyncMock) - @patch(f"{MODULE_PATH}.send_attachments", new_callable=AsyncMock) - async def test_relay_message_handles_attachment_http_error(self, send_attachments, send_webhook): - """The `relay_message` method should handle irretrievable attachments.""" - message = helpers.MockMessage(clean_content="message", attachments=["attachment"]) - - self.cog.webhook = helpers.MockAsyncWebhook() - log = logging.getLogger("bot.cogs.duck_pond") - - side_effect = discord.HTTPException(MagicMock(), "") - send_attachments.side_effect = side_effect - with self.subTest(side_effect=type(side_effect).__name__): - with self.assertLogs(logger=log, level=logging.ERROR) as log_watcher: - await self.cog.relay_message(message) - - send_webhook.assert_called_once_with( - webhook=self.cog.webhook, - content=message.clean_content, - username=message.author.display_name, - avatar_url=message.author.avatar_url - ) - - self.assertEqual(len(log_watcher.records), 1) - - record = log_watcher.records[0] - self.assertEqual(record.levelno, logging.ERROR) - - def _mock_payload(self, label: str, is_custom_emoji: bool, id_: int, emoji_name: str): - """Creates a mock `on_raw_reaction_add` payload with the specified emoji data.""" - payload = MagicMock(name=label) - payload.emoji.is_custom_emoji.return_value = is_custom_emoji - payload.emoji.id = id_ - payload.emoji.name = emoji_name - return payload - - async def test_payload_has_duckpond_emoji_correctly_detects_relevant_emojis(self): - """The `on_raw_reaction_add` event handler should ignore irrelevant emojis.""" - test_values = ( - # Custom Emojis - ( - self._mock_payload( - label="Custom Duckpond Emoji", - is_custom_emoji=True, - id_=constants.DuckPond.custom_emojis[0], - emoji_name="" - ), - True - ), - ( - self._mock_payload( - label="Custom Non-Duckpond Emoji", - is_custom_emoji=True, - id_=123, - emoji_name="" - ), - False - ), - # Unicode Emojis - ( - self._mock_payload( - label="Unicode Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.unicode_duck_emoji - ), - True - ), - ( - self._mock_payload( - label="Unicode Non-Duck Emoji", - is_custom_emoji=False, - id_=1, - emoji_name=self.thumbs_up_emoji - ), - False - ), - ) - - for payload, expected_return in test_values: - actual_return = self.cog._payload_has_duckpond_emoji(payload) - with self.subTest(case=payload._mock_name, expected_return=expected_return, actual_return=actual_return): - self.assertEqual(expected_return, actual_return) - - @patch(f"{MODULE_PATH}.discord.utils.get") - @patch(f"{MODULE_PATH}.DuckPond._payload_has_duckpond_emoji", new=MagicMock(return_value=False)) - def test_on_raw_reaction_add_returns_early_with_payload_without_duck_emoji(self, utils_get): - """The `on_raw_reaction_add` method should return early if the payload does not contain a duck emoji.""" - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload=MagicMock()))) - - # Ensure we've returned before making an unnecessary API call in the lines of code after the emoji check - utils_get.assert_not_called() - - def _raw_reaction_mocks(self, channel_id, message_id, user_id): - """Sets up mocks for tests of the `on_raw_reaction_add` event listener.""" - channel = helpers.MockTextChannel(id=channel_id) - self.bot.get_all_channels.return_value = (channel,) - - message = helpers.MockMessage(id=message_id) - - channel.fetch_message.return_value = message - - member = helpers.MockMember(id=user_id, roles=[self.staff_role]) - message.guild.members = (member,) - - payload = MagicMock(channel_id=channel_id, message_id=message_id, user_id=user_id) - - return channel, message, member, payload - - async def test_on_raw_reaction_add_returns_for_bot_and_non_staff_members(self): - """The `on_raw_reaction_add` event handler should return for bot users or non-staff members.""" - channel_id = 1234 - message_id = 2345 - user_id = 3456 - - channel, message, _, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - test_cases = ( - ("non-staff member", helpers.MockMember(id=user_id)), - ("bot staff member", helpers.MockMember(id=user_id, roles=[self.staff_role], bot=True)), - ) - - payload.emoji = self.duck_pond_emoji - - for description, member in test_cases: - message.guild.members = (member, ) - with self.subTest(test_case=description), patch(f"{MODULE_PATH}.DuckPond.has_green_checkmark") as checkmark: - checkmark.side_effect = AssertionError( - "Expected method to return before calling `self.has_green_checkmark`." - ) - self.assertIsNone(await self.cog.on_raw_reaction_add(payload)) - - # Check that we did make it past the payload checks - channel.fetch_message.assert_called_once() - channel.fetch_message.reset_mock() - - @patch(f"{MODULE_PATH}.DuckPond.is_staff") - @patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) - def test_on_raw_reaction_add_returns_on_message_with_green_checkmark_placed_by_bot(self, count_ducks, is_staff): - """The `on_raw_reaction_add` event should return when the message has a green check mark placed by the bot.""" - channel_id = 31415926535 - message_id = 27182818284 - user_id = 16180339887 - - channel, message, member, payload = self._raw_reaction_mocks(channel_id, message_id, user_id) - - payload.emoji = helpers.MockPartialEmoji(name=self.unicode_duck_emoji) - payload.emoji.is_custom_emoji.return_value = False - - message.reactions = [helpers.MockReaction(emoji=self.checkmark_emoji, users=[self.bot.user])] - - is_staff.return_value = True - count_ducks.side_effect = AssertionError("Expected method to return before calling `self.count_ducks`") - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_add(payload))) - - # Assert that we've made it past `self.is_staff` - is_staff.assert_called_once() - - async def test_on_raw_reaction_add_does_not_relay_below_duck_threshold(self): - """The `on_raw_reaction_add` listener should not relay messages or attachments below the duck threshold.""" - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - - channel, message, member, payload = self._raw_reaction_mocks(channel_id=3, message_id=4, user_id=5) - - payload.emoji = self.duck_pond_emoji - - for duck_count, should_relay in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.relay_message", new_callable=AsyncMock) as relay_message: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_relay=should_relay): - await self.cog.on_raw_reaction_add(payload) - - # Confirm that we've made it past counting - count_ducks.assert_called_once() - - # Did we relay a message? - has_relayed = relay_message.called - self.assertEqual(has_relayed, should_relay) - - if should_relay: - relay_message.assert_called_once_with(message) - - async def test_on_raw_reaction_remove_prevents_removal_of_green_checkmark_depending_on_the_duck_count(self): - """The `on_raw_reaction_remove` listener prevents removal of the check mark on messages with enough ducks.""" - checkmark = helpers.MockPartialEmoji(name=self.checkmark_emoji) - - message = helpers.MockMessage(id=1234) - - channel = helpers.MockTextChannel(id=98765) - channel.fetch_message.return_value = message - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(channel_id=channel.id, message_id=message.id, emoji=checkmark) - - test_cases = ( - (constants.DuckPond.threshold - 1, False), - (constants.DuckPond.threshold, True), - (constants.DuckPond.threshold + 1, True), - ) - for duck_count, should_re_add_checkmark in test_cases: - with patch(f"{MODULE_PATH}.DuckPond.count_ducks", new_callable=AsyncMock) as count_ducks: - count_ducks.return_value = duck_count - with self.subTest(duck_count=duck_count, should_re_add_checkmark=should_re_add_checkmark): - await self.cog.on_raw_reaction_remove(payload) - - # Check if we fetched the message - channel.fetch_message.assert_called_once_with(message.id) - - # Check if we actually counted the number of ducks - count_ducks.assert_called_once_with(message) - - has_re_added_checkmark = message.add_reaction.called - self.assertEqual(should_re_add_checkmark, has_re_added_checkmark) - - if should_re_add_checkmark: - message.add_reaction.assert_called_once_with(self.checkmark_emoji) - message.add_reaction.reset_mock() - - # reset mocks - channel.fetch_message.reset_mock() - message.reset_mock() - - def test_on_raw_reaction_remove_ignores_removal_of_non_checkmark_reactions(self): - """The `on_raw_reaction_remove` listener should ignore the removal of non-check mark emojis.""" - channel = helpers.MockTextChannel(id=98765) - - channel.fetch_message.side_effect = AssertionError( - "Expected method to return before calling `channel.fetch_message`" - ) - - self.bot.get_all_channels.return_value = (channel, ) - - payload = MagicMock(emoji=helpers.MockPartialEmoji(name=self.thumbs_up_emoji), channel_id=channel.id) - - self.assertIsNone(asyncio.run(self.cog.on_raw_reaction_remove(payload))) - - channel.fetch_message.assert_not_called() - - -class DuckPondSetupTests(unittest.TestCase): - """Tests setup of the `DuckPond` cog.""" - - def test_setup(self): - """Setup of the extension should call add_cog.""" - bot = helpers.MockBot() - duck_pond.setup(bot) - bot.add_cog.assert_called_once() -- cgit v1.2.3