From 1c569f2f38fe18d6210deec001046cf9ee68ea53 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Sat, 18 Jul 2020 16:54:01 +0200 Subject: Remove AntiMalWare constants, use cache data. Also updates the tests for this cog. --- tests/bot/cogs/test_antimalware.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index f219fc1ba..1e010d2ce 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -1,28 +1,33 @@ import unittest -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock from discord import NotFound from bot.cogs import antimalware -from bot.constants import AntiMalware as AntiMalwareConfig, Channels, STAFF_ROLES +from bot.constants import Channels, STAFF_ROLES from tests.helpers import MockAttachment, MockBot, MockMessage, MockRole -MODULE = "bot.cogs.antimalware" - -@patch(f"{MODULE}.AntiMalwareConfig.whitelist", new=[".first", ".second", ".third"]) class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Test the AntiMalware cog.""" def setUp(self): """Sets up fresh objects for each test.""" self.bot = MockBot() + self.bot.allow_deny_list_cache = { + "file_format.True": [ + {"content": ".first"}, + {"content": ".second"}, + {"content": ".third"} + ] + } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should not be deleted""" - attachment = MockAttachment(filename=f"python{AntiMalwareConfig.whitelist[0]}") + attachment = MockAttachment(filename="python.first") self.message.attachments = [attachment] await self.cog.on_message(self.message) @@ -93,7 +98,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.description, antimalware.TXT_EMBED_DESCRIPTION.format.return_value) antimalware.TXT_EMBED_DESCRIPTION.format.assert_called_with(cmd_channel_mention=cmd_channel.mention) - async def test_other_disallowed_extention_embed_description(self): + async def test_other_disallowed_extension_embed_description(self): """Test the description for a non .py/.txt disallowed extension.""" attachment = MockAttachment(filename="python.disallowed") self.message.attachments = [attachment] @@ -109,6 +114,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(embed.description, antimalware.DISALLOWED_EMBED_DESCRIPTION.format.return_value) antimalware.DISALLOWED_EMBED_DESCRIPTION.format.assert_called_with( + joined_whitelist=", ".join(self.whitelist), blocked_extensions_str=".disallowed", meta_channel_mention=meta_channel.mention ) @@ -135,7 +141,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """The return value should include all non-whitelisted extensions.""" test_values = ( ([], []), - (AntiMalwareConfig.whitelist, []), + (self.whitelist, []), ([".first"], []), ([".first", ".disallowed"], [".disallowed"]), ([".disallowed"], [".disallowed"]), @@ -145,7 +151,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): for extensions, expected_disallowed_extensions in test_values: with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): self.message.attachments = [MockAttachment(filename=f"filename{extension}") for extension in extensions] - disallowed_extensions = self.cog.get_disallowed_extensions(self.message) + disallowed_extensions = self.cog._get_disallowed_extensions(self.message) self.assertCountEqual(disallowed_extensions, expected_disallowed_extensions) -- cgit v1.2.3 From 3d5faa421756fadb42590db92e8fee64578390d4 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Mon, 27 Jul 2020 10:26:10 +0200 Subject: Rename AllowDenyList to FilterLists --- bot/__main__.py | 2 +- bot/bot.py | 14 +-- bot/cogs/allow_deny_lists.py | 218 ------------------------------------- bot/cogs/antimalware.py | 2 +- bot/cogs/filter_lists.py | 218 +++++++++++++++++++++++++++++++++++++ bot/cogs/filtering.py | 16 +-- bot/converters.py | 10 +- tests/bot/cogs/test_antimalware.py | 2 +- 8 files changed, 241 insertions(+), 241 deletions(-) delete mode 100644 bot/cogs/allow_deny_lists.py create mode 100644 bot/cogs/filter_lists.py (limited to 'tests') diff --git a/bot/__main__.py b/bot/__main__.py index 932aa705c..c2271cd16 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -53,7 +53,7 @@ bot.load_extension("bot.cogs.verification") # Feature cogs bot.load_extension("bot.cogs.alias") -bot.load_extension("bot.cogs.allow_deny_lists") +bot.load_extension("bot.cogs.filter_lists") bot.load_extension("bot.cogs.defcon") bot.load_extension("bot.cogs.dm_relay") bot.load_extension("bot.cogs.duck_pond") diff --git a/bot/bot.py b/bot/bot.py index d834c151b..3dfb4e948 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -34,7 +34,7 @@ class Bot(commands.Bot): self.redis_ready = asyncio.Event() self.redis_closed = False self.api_client = api.APIClient(loop=self.loop) - self.allow_deny_list_cache = {} + self.filter_list_cache = {} self._connector = None self._resolver = None @@ -50,9 +50,9 @@ class Bot(commands.Bot): self.stats = AsyncStatsClient(self.loop, statsd_url, 8125, prefix="bot") - async def _cache_allow_deny_list_data(self) -> None: - """Cache all the data in the AllowDenyList on the site.""" - full_cache = await self.api_client.get('bot/allow_deny_lists') + async def _cache_filter_list_data(self) -> None: + """Cache all the data in the FilterList on the site.""" + full_cache = await self.api_client.get('bot/filter-lists') for item in full_cache: type_ = item.get("type") @@ -64,7 +64,7 @@ class Bot(commands.Bot): "created_at": item.get("created_at"), "updated_at": item.get("updated_at"), } - self.allow_deny_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) + self.filter_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) async def _create_redis_session(self) -> None: """ @@ -176,8 +176,8 @@ class Bot(commands.Bot): self.http_session = aiohttp.ClientSession(connector=self._connector) self.api_client.recreate(force=True, connector=self._connector) - # Build the AllowDenyList cache - self.loop.create_task(self._cache_allow_deny_list_data()) + # Build the FilterList cache + self.loop.create_task(self._cache_filter_list_data()) async def on_guild_available(self, guild: discord.Guild) -> None: """ diff --git a/bot/cogs/allow_deny_lists.py b/bot/cogs/allow_deny_lists.py deleted file mode 100644 index e28e32bd6..000000000 --- a/bot/cogs/allow_deny_lists.py +++ /dev/null @@ -1,218 +0,0 @@ -import logging -from typing import Optional - -from discord import Colour, Embed -from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group - -from bot import constants -from bot.api import ResponseCodeError -from bot.bot import Bot -from bot.converters import ValidAllowDenyListType, ValidDiscordServerInvite -from bot.pagination import LinePaginator -from bot.utils.checks import with_role_check - -log = logging.getLogger(__name__) - - -class AllowDenyLists(Cog): - """Commands for blacklisting and whitelisting things.""" - - def __init__(self, bot: Bot) -> None: - self.bot = bot - - async def _add_data( - self, - ctx: Context, - allowed: bool, - list_type: ValidAllowDenyListType, - content: str, - comment: Optional[str] = None, - ) -> None: - """Add an item to an allow or denylist.""" - allow_type = "whitelist" if allowed else "blacklist" - - # If this is a server invite, we gotta validate it. - if list_type == "GUILD_INVITE": - log.trace(f"{content} is a guild invite, attempting to validate.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, content) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's convert the content to an ID. - log.trace(f"{content} validated as server invite. Converting to ID.") - content = guild_data.get("id") - - # Unless the user has specified another comment, let's - # use the server name as the comment so that the list - # of guild IDs will be more easily readable when we - # display it. - if not comment: - comment = guild_data.get("name") - - # Try to add the item to the database - log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") - payload = { - 'allowed': allowed, - 'type': list_type, - 'content': content, - 'comment': comment, - } - - try: - item = await self.bot.api_client.post( - "bot/allow_deny_lists", - json=payload - ) - except ResponseCodeError as e: - if e.status == 500: - await ctx.message.add_reaction("❌") - log.debug( - f"{ctx.author} tried to add data to a {allow_type}, but the API returned 500, " - "probably because the request violated the UniqueConstraint." - ) - raise BadArgument( - f"Unable to add the item to the {allow_type}. " - "The item probably already exists. Keep in mind that a " - "blacklist and a whitelist for the same item cannot co-exist, " - "and we do not permit any duplicates." - ) - raise - - # Insert the item into the cache - type_ = item.get("type") - allowed = item.get("allowed") - metadata = { - "content": item.get("content"), - "comment": item.get("comment"), - "id": item.get("id"), - "created_at": item.get("created_at"), - "updated_at": item.get("updated_at"), - } - self.bot.allow_deny_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) - await ctx.message.add_reaction("✅") - - async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidAllowDenyListType, content: str) -> None: - """Remove an item from an allow or denylist.""" - item = None - allow_type = "whitelist" if allowed else "blacklist" - id_converter = IDConverter() - - # If this is a server invite, we need to convert it. - if list_type == "GUILD_INVITE" and not id_converter._get_id_match(content): - log.trace(f"{content} is a guild invite, attempting to validate.") - validator = ValidDiscordServerInvite() - guild_data = await validator.convert(ctx, content) - - # If we make it this far without raising a BadArgument, the invite is - # valid. Let's convert the content to an ID. - log.trace(f"{content} validated as server invite. Converting to ID.") - content = guild_data.get("id") - - # Find the content and delete it. - log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - for allow_list in self.bot.allow_deny_list_cache.get(f"{list_type}.{allowed}", []): - if content == allow_list.get("content"): - item = allow_list - break - - if item is not None: - await self.bot.api_client.delete( - f"bot/allow_deny_lists/{item.get('id')}" - ) - self.bot.allow_deny_list_cache[f"{list_type}.{allowed}"].remove(item) - await ctx.message.add_reaction("✅") - - async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidAllowDenyListType) -> None: - """Paginate and display all items in an allow or denylist.""" - allow_type = "whitelist" if allowed else "blacklist" - result = self.bot.allow_deny_list_cache.get(f"{list_type}.{allowed}", []) - - # Build a list of lines we want to show in the paginator - lines = [] - for item in result: - line = f"• `{item.get('content')}`" - - if item.get("comment"): - line += f" - {item.get('comment')}" - - lines.append(line) - lines = sorted(lines) - - # Build the embed - list_type_plural = list_type.lower().replace("_", " ").title() + "s" - embed = Embed( - title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", - colour=Colour.blue() - ) - log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") - - if result: - await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) - else: - embed.description = "Hmmm, seems like there's nothing here yet." - await ctx.send(embed=embed) - - @group(aliases=("allowlist", "allow", "al", "wl")) - async def whitelist(self, ctx: Context) -> None: - """Group for whitelisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @group(aliases=("denylist", "deny", "bl", "dl")) - async def blacklist(self, ctx: Context) -> None: - """Group for blacklisting commands.""" - if not ctx.invoked_subcommand: - await ctx.send_help(ctx.command) - - @whitelist.command(name="add", aliases=("a", "set")) - async def allow_add( - self, - ctx: Context, - list_type: ValidAllowDenyListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified allowlist.""" - await self._add_data(ctx, True, list_type, content, comment) - - @blacklist.command(name="add", aliases=("a", "set")) - async def deny_add( - self, - ctx: Context, - list_type: ValidAllowDenyListType, - content: str, - *, - comment: Optional[str] = None, - ) -> None: - """Add an item to the specified denylist.""" - await self._add_data(ctx, False, list_type, content, comment) - - @whitelist.command(name="remove", aliases=("delete", "rm",)) - async def allow_delete(self, ctx: Context, list_type: ValidAllowDenyListType, content: str) -> None: - """Remove an item from the specified allowlist.""" - await self._delete_data(ctx, True, list_type, content) - - @blacklist.command(name="remove", aliases=("delete", "rm",)) - async def deny_delete(self, ctx: Context, list_type: ValidAllowDenyListType, content: str) -> None: - """Remove an item from the specified denylist.""" - await self._delete_data(ctx, False, list_type, content) - - @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def allow_get(self, ctx: Context, list_type: ValidAllowDenyListType) -> None: - """Get the contents of a specified allowlist.""" - await self._list_all_data(ctx, True, list_type) - - @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) - async def deny_get(self, ctx: Context, list_type: ValidAllowDenyListType) -> None: - """Get the contents of a specified denylist.""" - await self._list_all_data(ctx, False, list_type) - - def cog_check(self, ctx: Context) -> bool: - """Only allow moderators to invoke the commands in this cog.""" - return with_role_check(ctx, *constants.MODERATION_ROLES) - - -def setup(bot: Bot) -> None: - """Load the AllowDenyLists cog.""" - bot.add_cog(AllowDenyLists(bot)) diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 5b56f937f..9a100b3fc 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -40,7 +40,7 @@ class AntiMalware(Cog): def _get_whitelisted_file_formats(self) -> list: """Get the file formats currently on the whitelist.""" - return [item['content'] for item in self.bot.allow_deny_list_cache['file_format.True']] + return [item['content'] for item in self.bot.filter_list_cache['file_format.True']] def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: """Get an iterable containing all the disallowed extensions of attachments.""" diff --git a/bot/cogs/filter_lists.py b/bot/cogs/filter_lists.py new file mode 100644 index 000000000..d1db9830e --- /dev/null +++ b/bot/cogs/filter_lists.py @@ -0,0 +1,218 @@ +import logging +from typing import Optional + +from discord import Colour, Embed +from discord.ext.commands import BadArgument, Cog, Context, IDConverter, group + +from bot import constants +from bot.api import ResponseCodeError +from bot.bot import Bot +from bot.converters import ValidDiscordServerInvite, ValidFilterListType +from bot.pagination import LinePaginator +from bot.utils.checks import with_role_check + +log = logging.getLogger(__name__) + + +class FilterLists(Cog): + """Commands for blacklisting and whitelisting things.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + async def _add_data( + self, + ctx: Context, + allowed: bool, + list_type: ValidFilterListType, + content: str, + comment: Optional[str] = None, + ) -> None: + """Add an item to a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + + # If this is a server invite, we gotta validate it. + if list_type == "GUILD_INVITE": + log.trace(f"{content} is a guild invite, attempting to validate.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, content) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's convert the content to an ID. + log.trace(f"{content} validated as server invite. Converting to ID.") + content = guild_data.get("id") + + # Unless the user has specified another comment, let's + # use the server name as the comment so that the list + # of guild IDs will be more easily readable when we + # display it. + if not comment: + comment = guild_data.get("name") + + # Try to add the item to the database + log.trace(f"Trying to add the {content} item to the {list_type} {allow_type}") + payload = { + 'allowed': allowed, + 'type': list_type, + 'content': content, + 'comment': comment, + } + + try: + item = await self.bot.api_client.post( + "bot/filter-lists", + json=payload + ) + except ResponseCodeError as e: + if e.status == 500: + await ctx.message.add_reaction("❌") + log.debug( + f"{ctx.author} tried to add data to a {allow_type}, but the API returned 500, " + "probably because the request violated the UniqueConstraint." + ) + raise BadArgument( + f"Unable to add the item to the {allow_type}. " + "The item probably already exists. Keep in mind that a " + "blacklist and a whitelist for the same item cannot co-exist, " + "and we do not permit any duplicates." + ) + raise + + # Insert the item into the cache + type_ = item.get("type") + allowed = item.get("allowed") + metadata = { + "content": item.get("content"), + "comment": item.get("comment"), + "id": item.get("id"), + "created_at": item.get("created_at"), + "updated_at": item.get("updated_at"), + } + self.bot.filter_list_cache.setdefault(f"{type_}.{allowed}", []).append(metadata) + await ctx.message.add_reaction("✅") + + async def _delete_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from a filterlist.""" + item = None + allow_type = "whitelist" if allowed else "blacklist" + id_converter = IDConverter() + + # If this is a server invite, we need to convert it. + if list_type == "GUILD_INVITE" and not id_converter._get_id_match(content): + log.trace(f"{content} is a guild invite, attempting to validate.") + validator = ValidDiscordServerInvite() + guild_data = await validator.convert(ctx, content) + + # If we make it this far without raising a BadArgument, the invite is + # valid. Let's convert the content to an ID. + log.trace(f"{content} validated as server invite. Converting to ID.") + content = guild_data.get("id") + + # Find the content and delete it. + log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") + for allow_list in self.bot.filter_list_cache.get(f"{list_type}.{allowed}", []): + if content == allow_list.get("content"): + item = allow_list + break + + if item is not None: + await self.bot.api_client.delete( + f"bot/filter-lists/{item.get('id')}" + ) + self.bot.filter_list_cache[f"{list_type}.{allowed}"].remove(item) + await ctx.message.add_reaction("✅") + + async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: + """Paginate and display all items in a filterlist.""" + allow_type = "whitelist" if allowed else "blacklist" + result = self.bot.filter_list_cache.get(f"{list_type}.{allowed}", []) + + # Build a list of lines we want to show in the paginator + lines = [] + for item in result: + line = f"• `{item.get('content')}`" + + if item.get("comment"): + line += f" - {item.get('comment')}" + + lines.append(line) + lines = sorted(lines) + + # Build the embed + list_type_plural = list_type.lower().replace("_", " ").title() + "s" + embed = Embed( + title=f"{allow_type.title()}ed {list_type_plural} ({len(result)} total)", + colour=Colour.blue() + ) + log.trace(f"Trying to list {len(result)} items from the {list_type.lower()} {allow_type}") + + if result: + await LinePaginator.paginate(lines, ctx, embed, max_lines=15, empty=False) + else: + embed.description = "Hmmm, seems like there's nothing here yet." + await ctx.send(embed=embed) + + @group(aliases=("allowlist", "allow", "al", "wl")) + async def whitelist(self, ctx: Context) -> None: + """Group for whitelisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @group(aliases=("denylist", "deny", "bl", "dl")) + async def blacklist(self, ctx: Context) -> None: + """Group for blacklisting commands.""" + if not ctx.invoked_subcommand: + await ctx.send_help(ctx.command) + + @whitelist.command(name="add", aliases=("a", "set")) + async def allow_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified allowlist.""" + await self._add_data(ctx, True, list_type, content, comment) + + @blacklist.command(name="add", aliases=("a", "set")) + async def deny_add( + self, + ctx: Context, + list_type: ValidFilterListType, + content: str, + *, + comment: Optional[str] = None, + ) -> None: + """Add an item to the specified denylist.""" + await self._add_data(ctx, False, list_type, content, comment) + + @whitelist.command(name="remove", aliases=("delete", "rm",)) + async def allow_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified allowlist.""" + await self._delete_data(ctx, True, list_type, content) + + @blacklist.command(name="remove", aliases=("delete", "rm",)) + async def deny_delete(self, ctx: Context, list_type: ValidFilterListType, content: str) -> None: + """Remove an item from the specified denylist.""" + await self._delete_data(ctx, False, list_type, content) + + @whitelist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def allow_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified allowlist.""" + await self._list_all_data(ctx, True, list_type) + + @blacklist.command(name="get", aliases=("list", "ls", "fetch", "show")) + async def deny_get(self, ctx: Context, list_type: ValidFilterListType) -> None: + """Get the contents of a specified denylist.""" + await self._list_all_data(ctx, False, list_type) + + def cog_check(self, ctx: Context) -> bool: + """Only allow moderators to invoke the commands in this cog.""" + return with_role_check(ctx, *constants.MODERATION_ROLES) + + +def setup(bot: Bot) -> None: + """Load the FilterLists cog.""" + bot.add_cog(FilterLists(bot)) diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 8897cbaf9..652af5ff5 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -99,9 +99,9 @@ class Filtering(Cog): self.bot.loop.create_task(self.reschedule_offensive_msg_deletion()) - def _get_allowlist_items(self, list_type: str, *, allowed: bool, compiled: Optional[bool] = False) -> list: - """Fetch items from the allow_deny_list_cache.""" - items = self.bot.allow_deny_list_cache.get(f"{list_type.upper()}.{allowed}", []) + def _get_filterlist_items(self, list_type: str, *, allowed: bool, compiled: Optional[bool] = False) -> list: + """Fetch items from the filter_list_cache.""" + items = self.bot.filter_list_cache.get(f"{list_type.upper()}.{allowed}", []) if compiled: return [re.compile(fr'{item["content"]}', flags=re.IGNORECASE) for item in items] @@ -143,7 +143,7 @@ class Filtering(Cog): def get_name_matches(self, name: str) -> List[re.Match]: """Check bad words from passed string (name). Return list of matches.""" matches = [] - watchlist_patterns = self._get_allowlist_items('word_watchlist', allowed=False, compiled=True) + watchlist_patterns = self._get_filterlist_items('word_watchlist', allowed=False, compiled=True) for pattern in watchlist_patterns: if match := pattern.search(name): matches.append(match) @@ -408,7 +408,7 @@ class Filtering(Cog): if URL_RE.search(text): return False - watchlist_patterns = self._get_allowlist_items('word_watchlist', allowed=False, compiled=True) + watchlist_patterns = self._get_filterlist_items('word_watchlist', allowed=False, compiled=True) for pattern in watchlist_patterns: match = pattern.search(text) if match: @@ -420,7 +420,7 @@ class Filtering(Cog): return False text = text.lower() - domain_blacklist = self._get_allowlist_items("domain_name", allowed=False) + domain_blacklist = self._get_filterlist_items("domain_name", allowed=False) for url in domain_blacklist: if url.lower() in text: @@ -468,8 +468,8 @@ class Filtering(Cog): return True guild_id = guild.get("id") - guild_invite_whitelist = self._get_allowlist_items("guild_invite", allowed=True) - guild_invite_blacklist = self._get_allowlist_items("guild_invite", allowed=False) + guild_invite_whitelist = self._get_filterlist_items("guild_invite", allowed=True) + guild_invite_blacklist = self._get_filterlist_items("guild_invite", allowed=False) # Is this invite allowed? guild_partnered_or_verified = ( diff --git a/bot/converters.py b/bot/converters.py index 41cd3f3e5..158bf1a16 100644 --- a/bot/converters.py +++ b/bot/converters.py @@ -72,18 +72,18 @@ class ValidDiscordServerInvite(Converter): raise BadArgument("This does not appear to be a valid Discord server invite.") -class ValidAllowDenyListType(Converter): +class ValidFilterListType(Converter): """ - A converter that checks whether the given string is a valid AllowDenyList type. + A converter that checks whether the given string is a valid FilterList type. - Raises `BadArgument` if the argument is not a valid AllowDenyList type, and simply + Raises `BadArgument` if the argument is not a valid FilterList type, and simply passes through the given argument otherwise. """ async def convert(self, ctx: Context, list_type: str) -> str: - """Checks whether the given string is a valid AllowDenyList type.""" + """Checks whether the given string is a valid FilterList type.""" try: - valid_types = await ctx.bot.api_client.get('bot/allow_deny_lists/get_types') + valid_types = await ctx.bot.api_client.get('bot/filter-lists/get-types') except ResponseCodeError: raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.") diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 1e010d2ce..664fa8f19 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -14,7 +14,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): def setUp(self): """Sets up fresh objects for each test.""" self.bot = MockBot() - self.bot.allow_deny_list_cache = { + self.bot.filter_list_cache = { "file_format.True": [ {"content": ".first"}, {"content": ".second"}, -- cgit v1.2.3 From e0837f4f6dd7c5c2d6fc0811dccfaf1ecae768ba Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 29 Jul 2020 20:14:52 +0200 Subject: Restructure bot.filter_list_cache. This is an optimization designed to eliminate all the list comprehensions we were doing inside antimalware and filtering. The cache is now structured so that the content is the key and the metadata is the value. --- bot/bot.py | 8 ++++---- bot/cogs/antimalware.py | 2 +- bot/cogs/filter_lists.py | 18 +++++++++--------- bot/cogs/filtering.py | 3 +-- tests/bot/cogs/test_antimalware.py | 10 +++++----- 5 files changed, 20 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/bot/bot.py b/bot/bot.py index 5deb986ec..4492feaa9 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -35,7 +35,7 @@ class Bot(commands.Bot): self.redis_ready = asyncio.Event() self.redis_closed = False self.api_client = api.APIClient(loop=self.loop) - self.filter_list_cache = defaultdict(list) + self.filter_list_cache = defaultdict(dict) self._connector = None self._resolver = None @@ -169,14 +169,14 @@ class Bot(commands.Bot): """Add an item to the bots filter_list_cache.""" type_ = item["type"] allowed = item["allowed"] - metadata = { + content = item["content"] + + self.filter_list_cache[f"{type_}.{allowed}"][content] = { "id": item["id"], - "content": item["content"], "comment": item["comment"], "created_at": item["created_at"], "updated_at": item["updated_at"], } - self.filter_list_cache[f"{type_}.{allowed}"].append(metadata) async def login(self, *args, **kwargs) -> None: """Re-create the connector and set up sessions before logging into Discord.""" diff --git a/bot/cogs/antimalware.py b/bot/cogs/antimalware.py index 9a100b3fc..c76bd2c60 100644 --- a/bot/cogs/antimalware.py +++ b/bot/cogs/antimalware.py @@ -40,7 +40,7 @@ class AntiMalware(Cog): def _get_whitelisted_file_formats(self) -> list: """Get the file formats currently on the whitelist.""" - return [item['content'] for item in self.bot.filter_list_cache['file_format.True']] + return self.bot.filter_list_cache['FILE_FORMAT.True'].keys() def _get_disallowed_extensions(self, message: Message) -> t.Iterable[str]: """Get an iterable containing all the disallowed extensions of attachments.""" diff --git a/bot/cogs/filter_lists.py b/bot/cogs/filter_lists.py index a93de2de9..3331be014 100644 --- a/bot/cogs/filter_lists.py +++ b/bot/cogs/filter_lists.py @@ -88,16 +88,16 @@ class FilterLists(Cog): # Find the content and delete it. log.trace(f"Trying to delete the {content} item from the {list_type} {allow_type}") - for allow_list in self.bot.filter_list_cache[f"{list_type}.{allowed}"]: - if content == allow_list.get("content"): - item = allow_list + for allow_list, metadata in self.bot.filter_list_cache[f"{list_type}.{allowed}"].items(): + if content == allow_list: + item = metadata break if item is not None: await self.bot.api_client.delete( - f"bot/filter-lists/{item.get('id')}" + f"bot/filter-lists/{item['id']}" ) - self.bot.filter_list_cache[f"{list_type}.{allowed}"].remove(item) + del self.bot.filter_list_cache[f"{list_type}.{allowed}"][content] await ctx.message.add_reaction("✅") async def _list_all_data(self, ctx: Context, allowed: bool, list_type: ValidFilterListType) -> None: @@ -107,11 +107,11 @@ class FilterLists(Cog): # Build a list of lines we want to show in the paginator lines = [] - for item in result: - line = f"• `{item.get('content')}`" + for content, metadata in result.items(): + line = f"• `{content}`" - if item.get("comment"): - line += f" - {item.get('comment')}" + if metadata.get("comment"): + line += f" - {metadata.get('comment')}" lines.append(line) lines = sorted(lines) diff --git a/bot/cogs/filtering.py b/bot/cogs/filtering.py index 7787d396d..0951cb740 100644 --- a/bot/cogs/filtering.py +++ b/bot/cogs/filtering.py @@ -101,8 +101,7 @@ class Filtering(Cog): def _get_filterlist_items(self, list_type: str, *, allowed: bool) -> list: """Fetch items from the filter_list_cache.""" - items = self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"] - return [item["content"] for item in items] + return self.bot.filter_list_cache[f"{list_type.upper()}.{allowed}"].keys() @staticmethod def _expand_spoilers(text: str) -> str: diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 664fa8f19..82eadf226 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -15,11 +15,11 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Sets up fresh objects for each test.""" self.bot = MockBot() self.bot.filter_list_cache = { - "file_format.True": [ - {"content": ".first"}, - {"content": ".second"}, - {"content": ".third"} - ] + "file_format.True": { + ".first": {}, + ".second": {}, + ".third": {}, + } } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() -- cgit v1.2.3 From 0cfc918c6d68764c380f1188f3bc5508e6b27030 Mon Sep 17 00:00:00 2001 From: Leon Sandøy Date: Wed, 29 Jul 2020 20:24:06 +0200 Subject: Fix broken antimalware tests. --- tests/bot/cogs/test_antimalware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index 82eadf226..ecb7abf00 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -15,7 +15,7 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): """Sets up fresh objects for each test.""" self.bot = MockBot() self.bot.filter_list_cache = { - "file_format.True": { + "FILE_FORMAT.True": { ".first": {}, ".second": {}, ".third": {}, -- cgit v1.2.3 From ed4ebbf5f7ee751f87554831e277d270cf36ac40 Mon Sep 17 00:00:00 2001 From: Joseph Banks Date: Fri, 14 Aug 2020 21:19:04 +0100 Subject: Update tests for user commands --- tests/bot/cogs/test_information.py | 87 ++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 32 deletions(-) (limited to 'tests') diff --git a/tests/bot/cogs/test_information.py b/tests/bot/cogs/test_information.py index 79c0e0ad3..77b0ddf17 100644 --- a/tests/bot/cogs/test_information.py +++ b/tests/bot/cogs/test_information.py @@ -215,10 +215,10 @@ class UserInfractionHelperMethodTests(unittest.TestCase): with self.subTest(method=method, api_response=api_response, expected_lines=expected_lines): self.bot.api_client.get.return_value = api_response - expected_output = "\n".join(default_header + expected_lines) + expected_output = "\n".join(expected_lines) actual_output = asyncio.run(method(self.member)) - self.assertEqual(expected_output, actual_output) + self.assertEqual((default_header, expected_output), actual_output) def test_basic_user_infraction_counts_returns_correct_strings(self): """The method should correctly list both the total and active number of non-hidden infractions.""" @@ -249,7 +249,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): }, ) - header = ["**Infractions**"] + header = "Infractions" self._method_subtests(self.cog.basic_user_infraction_counts, test_values, header) @@ -258,7 +258,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): test_values = ( { "api response": [], - "expected_lines": ["This user has never received an infraction."], + "expected_lines": ["No infractions"], }, # Shows non-hidden inactive infraction as expected { @@ -304,7 +304,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): }, ) - header = ["**Infractions**"] + header = "Infractions" self._method_subtests(self.cog.expanded_user_infraction_counts, test_values, header) @@ -313,15 +313,15 @@ class UserInfractionHelperMethodTests(unittest.TestCase): test_values = ( { "api response": [], - "expected_lines": ["This user has never been nominated."], + "expected_lines": ["No nominations"], }, { "api response": [{'active': True}], - "expected_lines": ["This user is **currently** nominated (1 nomination in total)."], + "expected_lines": ["This user is **currently** nominated", "(1 nomination in total)"], }, { "api response": [{'active': True}, {'active': False}], - "expected_lines": ["This user is **currently** nominated (2 nominations in total)."], + "expected_lines": ["This user is **currently** nominated", "(2 nominations in total)"], }, { "api response": [{'active': False}], @@ -334,7 +334,7 @@ class UserInfractionHelperMethodTests(unittest.TestCase): ) - header = ["**Nominations**"] + header = "Nominations" self._method_subtests(self.cog.user_nomination_counts, test_values, header) @@ -350,7 +350,10 @@ class UserEmbedTests(unittest.TestCase): self.bot.api_client.get = unittest.mock.AsyncMock() self.cog = information.Information(self.bot) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_string_representation_of_user_in_title_if_nick_is_not_available(self): """The embed should use the string representation of the user if they don't have a nick.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -362,7 +365,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Mr. Hemlock") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_nick_in_title_if_available(self): """The embed should use the nick if it's available.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -374,7 +380,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.title, "Cat lover (Mr. Hemlock)") - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_ignores_everyone_role(self): """Created `!user` embeds should not contain mention of the @everyone-role.""" ctx = helpers.MockContext(channel=helpers.MockTextChannel(id=1)) @@ -386,8 +395,8 @@ class UserEmbedTests(unittest.TestCase): embed = asyncio.run(self.cog.create_user_embed(ctx, user)) - self.assertIn("&Admins", embed.description) - self.assertNotIn("&Everyone", embed.description) + self.assertIn("&Admins", embed.fields[1].value) + self.assertNotIn("&Everyone", embed.fields[1].value) @unittest.mock.patch(f"{COG_PATH}.expanded_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @unittest.mock.patch(f"{COG_PATH}.user_nomination_counts", new_callable=unittest.mock.AsyncMock) @@ -398,8 +407,8 @@ class UserEmbedTests(unittest.TestCase): moderators_role = helpers.MockRole(name='Moderators') moderators_role.colour = 100 - infraction_counts.return_value = "expanded infractions info" - nomination_counts.return_value = "nomination info" + infraction_counts.return_value = ("Infractions", "expanded infractions info") + nomination_counts.return_value = ("Nominations", "nomination info") user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) embed = asyncio.run(self.cog.create_user_embed(ctx, user)) @@ -409,20 +418,19 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual( textwrap.dedent(f""" - **User Information** Created: {"1 year ago"} Profile: {user.mention} ID: {user.id} + """).strip(), + embed.fields[0].value + ) - **Member Information** + self.assertEqual( + textwrap.dedent(f""" Joined: {"1 year ago"} Roles: &Moderators - - expanded infractions info - - nomination info """).strip(), - embed.description + embed.fields[1].value ) @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new_callable=unittest.mock.AsyncMock) @@ -433,7 +441,7 @@ class UserEmbedTests(unittest.TestCase): moderators_role = helpers.MockRole(name='Moderators') moderators_role.colour = 100 - infraction_counts.return_value = "basic infractions info" + infraction_counts.return_value = ("Infractions", "basic infractions info") user = helpers.MockMember(id=314, roles=[moderators_role], top_role=moderators_role) embed = asyncio.run(self.cog.create_user_embed(ctx, user)) @@ -442,21 +450,30 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual( textwrap.dedent(f""" - **User Information** Created: {"1 year ago"} Profile: {user.mention} ID: {user.id} + """).strip(), + embed.fields[0].value + ) - **Member Information** + self.assertEqual( + textwrap.dedent(f""" Joined: {"1 year ago"} Roles: &Moderators - - basic infractions info """).strip(), - embed.description + embed.fields[1].value + ) + + self.assertEqual( + "basic infractions info", + embed.fields[3].value ) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_top_role_colour_when_user_has_roles(self): """The embed should be created with the colour of the top role, if a top role is available.""" ctx = helpers.MockContext() @@ -469,7 +486,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour(moderators_role.colour)) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_blurple_colour_when_user_has_no_roles(self): """The embed should be created with a blurple colour if the user has no assigned roles.""" ctx = helpers.MockContext() @@ -479,7 +499,10 @@ class UserEmbedTests(unittest.TestCase): self.assertEqual(embed.colour, discord.Colour.blurple()) - @unittest.mock.patch(f"{COG_PATH}.basic_user_infraction_counts", new=unittest.mock.AsyncMock(return_value="")) + @unittest.mock.patch( + f"{COG_PATH}.basic_user_infraction_counts", + new=unittest.mock.AsyncMock(return_value=("Infractions", "basic infractions")) + ) def test_create_user_embed_uses_png_format_of_user_avatar_as_thumbnail(self): """The embed thumbnail should be set to the user's avatar in `png` format.""" ctx = helpers.MockContext() -- cgit v1.2.3 From 520ac0f9871bf6775d76eea753ed2a940704e92d Mon Sep 17 00:00:00 2001 From: MarkKoz Date: Sat, 22 Aug 2020 20:44:48 -0700 Subject: Include root aliases in the command name conflict test --- tests/bot/cogs/test_cogs.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_cogs.py b/tests/bot/cogs/test_cogs.py index fdda59a8f..30a04422a 100644 --- a/tests/bot/cogs/test_cogs.py +++ b/tests/bot/cogs/test_cogs.py @@ -53,6 +53,7 @@ class CommandNameTests(unittest.TestCase): """Return a list of all qualified names, including aliases, for the `command`.""" names = [f"{command.full_parent_name} {alias}".strip() for alias in command.aliases] names.append(command.qualified_name) + names += getattr(command, "root_aliases", []) return names -- cgit v1.2.3 From b7644aa822def549e2591b53c69af3cf44355ac9 Mon Sep 17 00:00:00 2001 From: Xithrius Date: Mon, 31 Aug 2020 19:56:24 -0700 Subject: Removed ImagePaginator testing. --- tests/bot/test_pagination.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'tests') diff --git a/tests/bot/test_pagination.py b/tests/bot/test_pagination.py index ce880d457..630f2516d 100644 --- a/tests/bot/test_pagination.py +++ b/tests/bot/test_pagination.py @@ -44,18 +44,3 @@ class LinePaginatorTests(TestCase): self.paginator.add_line('x' * (self.paginator.scale_to_size + 1)) # Note: item at index 1 is the truncated line, index 0 is prefix self.assertEqual(self.paginator._current_page[1], 'x' * self.paginator.scale_to_size) - - -class ImagePaginatorTests(TestCase): - """Tests functionality of the `ImagePaginator`.""" - - def setUp(self): - """Create a paginator for the test method.""" - self.paginator = pagination.ImagePaginator() - - def test_add_image_appends_image(self): - """`add_image` appends the image to the image list.""" - image = 'lemon' - self.paginator.add_image(image) - - assert self.paginator.images == [image] -- cgit v1.2.3 From 1a47f5d80f2f91c3da5a9626e9a6694381d49cd0 Mon Sep 17 00:00:00 2001 From: wookie184 Date: Tue, 1 Sep 2020 12:22:43 +0100 Subject: Fixed old tests and added 2 new ones --- tests/bot/cogs/test_antimalware.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'tests') diff --git a/tests/bot/cogs/test_antimalware.py b/tests/bot/cogs/test_antimalware.py index ecb7abf00..f50c0492d 100644 --- a/tests/bot/cogs/test_antimalware.py +++ b/tests/bot/cogs/test_antimalware.py @@ -23,6 +23,8 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): } self.cog = antimalware.AntiMalware(self.bot) self.message = MockMessage() + self.message.webhook_id = None + self.message.author.bot = None self.whitelist = [".first", ".second", ".third"] async def test_message_with_allowed_attachment(self): @@ -48,6 +50,26 @@ class AntiMalwareCogTests(unittest.IsolatedAsyncioTestCase): self.message.delete.assert_not_called() + async def test_webhook_message_with_illegal_extension(self): + """A webhook message containing an illegal extension should be ignored.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.webhook_id = 697140105563078727 + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + + async def test_bot_message_with_illegal_extension(self): + """A bot message containing an illegal extension should be ignored.""" + attachment = MockAttachment(filename="python.disallowed") + self.message.author.bot = 409107086526644234 + self.message.attachments = [attachment] + + await self.cog.on_message(self.message) + + self.message.delete.assert_not_called() + async def test_message_with_illegal_extension_gets_deleted(self): """A message containing an illegal extension should send an embed.""" attachment = MockAttachment(filename="python.disallowed") -- cgit v1.2.3