diff options
author | 2023-03-23 19:59:05 +0200 | |
---|---|---|
committer | 2023-03-23 19:59:05 +0200 | |
commit | 509c7968dab875f8e3e7934647c757a2a73f724b (patch) | |
tree | e2848cc84a408d7d6a704b65093ad984d59575c1 | |
parent | Fix filtering tests (diff) |
Add support for snekbox IO in the new filtering system
-rw-r--r-- | bot/exts/filtering/_filter_context.py | 19 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/antispam.py | 4 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/domain.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/extension.py | 56 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/invite.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_filter_lists/token.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_filters/unique/discord_token.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_filters/unique/everyone.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_filters/unique/webhook.py | 2 | ||||
-rw-r--r-- | bot/exts/filtering/_settings_types/actions/remove_context.py | 4 | ||||
-rw-r--r-- | bot/exts/filtering/_ui/ui.py | 6 | ||||
-rw-r--r-- | bot/exts/filtering/filtering.py | 17 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/_cog.py | 84 | ||||
-rw-r--r-- | bot/exts/utils/snekbox/_io.py | 10 | ||||
-rw-r--r-- | tests/bot/exts/filtering/test_extension_filter.py | 30 | ||||
-rw-r--r-- | tests/bot/exts/utils/snekbox/test_snekbox.py | 8 |
16 files changed, 134 insertions, 116 deletions
diff --git a/bot/exts/filtering/_filter_context.py b/bot/exts/filtering/_filter_context.py index 8e1ed5788..483706e2a 100644 --- a/bot/exts/filtering/_filter_context.py +++ b/bot/exts/filtering/_filter_context.py @@ -5,12 +5,14 @@ from collections.abc import Callable, Coroutine, Iterable from dataclasses import dataclass, field, replace from enum import Enum, auto +import discord from discord import DMChannel, Embed, Member, Message, TextChannel, Thread, User from bot.utils.message_cache import MessageCache if typing.TYPE_CHECKING: from bot.exts.filtering._filters.filter import Filter + from bot.exts.utils.snekbox._io import FileAttachment class Event(Enum): @@ -19,6 +21,7 @@ class Event(Enum): MESSAGE = auto() MESSAGE_EDIT = auto() NICKNAME = auto() + SNEKBOX = auto() @dataclass @@ -32,6 +35,7 @@ class FilterContext: content: str | Iterable # What actually needs filtering. The Iterable type depends on the filter list. message: Message | None # The message involved embeds: list[Embed] = field(default_factory=list) # Any embeds involved + attachments: list[discord.Attachment | FileAttachment] = field(default_factory=list) # Any attachments sent. before_message: Message | None = None message_cache: MessageCache | None = None # Output context @@ -45,11 +49,12 @@ class FilterContext: notification_domain: str = "" # A domain to send the user for context filter_info: dict['Filter', str] = field(default_factory=dict) # Additional info from a filter. messages_deletion: bool = False # Whether the messages were deleted. Can't upload deletion log otherwise. + blocked_exts: set[str] = field(default_factory=set) # Any extensions blocked (used for snekbox) # Additional actions to perform additional_actions: list[Callable[[FilterContext], Coroutine]] = field(default_factory=list) related_messages: set[Message] = field(default_factory=set) # Deletion will include these. related_channels: set[TextChannel | Thread | DMChannel] = field(default_factory=set) - attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs. + uploaded_attachments: dict[int, list[str]] = field(default_factory=dict) # Message ID to attachment URLs. upload_deletion_logs: bool = True # Whether it's allowed to upload deletion logs. @classmethod @@ -57,7 +62,17 @@ class FilterContext: cls, event: Event, message: Message, before: Message | None = None, cache: MessageCache | None = None ) -> FilterContext: """Create a filtering context from the attributes of a message.""" - return cls(event, message.author, message.channel, message.content, message, message.embeds, before, cache) + return cls( + event, + message.author, + message.channel, + message.content, + message, + message.embeds, + message.attachments, + before, + cache + ) def replace(self, **changes) -> FilterContext: """Return a new context object assigning new values to the specified fields.""" diff --git a/bot/exts/filtering/_filter_lists/antispam.py b/bot/exts/filtering/_filter_lists/antispam.py index 0e7ab2bdc..ba20051fc 100644 --- a/bot/exts/filtering/_filter_lists/antispam.py +++ b/bot/exts/filtering/_filter_lists/antispam.py @@ -171,7 +171,9 @@ class DeletionContext: new_ctx.related_channels = reduce( or_, (other_ctx.related_channels for other_ctx in other_contexts), ctx.related_channels ) | {ctx.channel for ctx in other_contexts} - new_ctx.attachments = reduce(or_, (other_ctx.attachments for other_ctx in other_contexts), ctx.attachments) + new_ctx.uploaded_attachments = reduce( + or_, (other_ctx.uploaded_attachments for other_ctx in other_contexts), ctx.uploaded_attachments + ) new_ctx.upload_deletion_logs = True new_ctx.messages_deletion = all(ctx.messages_deletion for ctx in self.contexts) diff --git a/bot/exts/filtering/_filter_lists/domain.py b/bot/exts/filtering/_filter_lists/domain.py index f4062edfe..091fd14e0 100644 --- a/bot/exts/filtering/_filter_lists/domain.py +++ b/bot/exts/filtering/_filter_lists/domain.py @@ -31,7 +31,7 @@ class DomainsList(FilterList[DomainFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) def get_filter_type(self, content: str) -> type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_filter_lists/extension.py b/bot/exts/filtering/_filter_lists/extension.py index a739d7191..868fde2b2 100644 --- a/bot/exts/filtering/_filter_lists/extension.py +++ b/bot/exts/filtering/_filter_lists/extension.py @@ -49,7 +49,7 @@ class ExtensionsList(FilterList[ExtensionFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE) + filtering_cog.subscribe(self, Event.MESSAGE, Event.SNEKBOX) self._whitelisted_description = None def get_filter_type(self, content: str) -> type[Filter]: @@ -66,7 +66,7 @@ class ExtensionsList(FilterList[ExtensionFilter]): ) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]: """Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods.""" # Return early if the message doesn't have attachments. - if not ctx.message or not ctx.message.attachments: + if not ctx.message or not ctx.attachments: return None, [], {} _, failed = self[ListType.ALLOW].defaults.validations.evaluate(ctx) @@ -75,7 +75,7 @@ class ExtensionsList(FilterList[ExtensionFilter]): # Find all extensions in the message. all_ext = { - (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.message.attachments + (splitext(attachment.filename.lower())[1], attachment.filename) for attachment in ctx.attachments } new_ctx = ctx.replace(content={ext for ext, _ in all_ext}) # And prepare the context for the filters to read. triggered = [ @@ -86,31 +86,37 @@ class ExtensionsList(FilterList[ExtensionFilter]): # See if there are any extensions left which aren't allowed. not_allowed = {ext: filename for ext, filename in all_ext if ext not in allowed_ext} + if ctx.event == Event.SNEKBOX: + not_allowed = {ext: filename for ext, filename in not_allowed.items() if ext not in TXT_LIKE_FILES} + if not not_allowed: # Yes, it's a double negative. Meaning all attachments are allowed :) return None, [], {ListType.ALLOW: triggered} - # Something is disallowed. - if ".py" in not_allowed: - # Provide a pastebin link for .py files. - ctx.dm_embed = PY_EMBED_DESCRIPTION - elif txt_extensions := {ext for ext in TXT_LIKE_FILES if ext in not_allowed}: - # Work around Discord auto-conversion of messages longer than 2000 chars to .txt - cmd_channel = bot.instance.get_channel(Channels.bot_commands) - ctx.dm_embed = TXT_EMBED_DESCRIPTION.format( - blocked_extension=txt_extensions.pop(), - cmd_channel_mention=cmd_channel.mention - ) - else: - meta_channel = bot.instance.get_channel(Channels.meta) - if not self._whitelisted_description: - self._whitelisted_description = ', '.join( - filter_.content for filter_ in self[ListType.ALLOW].filters.values() + # At this point, something is disallowed. + if ctx.event != Event.SNEKBOX: # Don't post the embed if it's a snekbox response. + if ".py" in not_allowed: + # Provide a pastebin link for .py files. + ctx.dm_embed = PY_EMBED_DESCRIPTION + elif txt_extensions := {ext for ext in TXT_LIKE_FILES if ext in not_allowed}: + # Work around Discord auto-conversion of messages longer than 2000 chars to .txt + cmd_channel = bot.instance.get_channel(Channels.bot_commands) + ctx.dm_embed = TXT_EMBED_DESCRIPTION.format( + blocked_extension=txt_extensions.pop(), + cmd_channel_mention=cmd_channel.mention + ) + else: + meta_channel = bot.instance.get_channel(Channels.meta) + if not self._whitelisted_description: + self._whitelisted_description = ', '.join( + filter_.content for filter_ in self[ListType.ALLOW].filters.values() + ) + ctx.dm_embed = DISALLOWED_EMBED_DESCRIPTION.format( + joined_whitelist=self._whitelisted_description, + blocked_extensions_str=", ".join(not_allowed), + meta_channel_mention=meta_channel.mention, ) - ctx.dm_embed = DISALLOWED_EMBED_DESCRIPTION.format( - joined_whitelist=self._whitelisted_description, - blocked_extensions_str=", ".join(not_allowed), - meta_channel_mention=meta_channel.mention, - ) ctx.matches += not_allowed.values() - return self[ListType.ALLOW].defaults.actions, [f"`{ext}`" for ext in not_allowed], {ListType.ALLOW: triggered} + ctx.blocked_exts |= set(not_allowed) + actions = self[ListType.ALLOW].defaults.actions if ctx.event != Event.SNEKBOX else None + return actions, [f"`{ext}`" for ext in not_allowed], {ListType.ALLOW: triggered} diff --git a/bot/exts/filtering/_filter_lists/invite.py b/bot/exts/filtering/_filter_lists/invite.py index bd0eaa122..b9732a6dc 100644 --- a/bot/exts/filtering/_filter_lists/invite.py +++ b/bot/exts/filtering/_filter_lists/invite.py @@ -37,7 +37,7 @@ class InviteList(FilterList[InviteFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) def get_filter_type(self, content: str) -> type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_filter_lists/token.py b/bot/exts/filtering/_filter_lists/token.py index f5da28bb5..0c591ac3b 100644 --- a/bot/exts/filtering/_filter_lists/token.py +++ b/bot/exts/filtering/_filter_lists/token.py @@ -32,7 +32,7 @@ class TokensList(FilterList[TokenFilter]): def __init__(self, filtering_cog: Filtering): super().__init__() - filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.NICKNAME) + filtering_cog.subscribe(self, Event.MESSAGE, Event.MESSAGE_EDIT, Event.NICKNAME, Event.SNEKBOX) def get_filter_type(self, content: str) -> type[Filter]: """Get a subclass of filter matching the filter list and the filter's content.""" diff --git a/bot/exts/filtering/_filters/unique/discord_token.py b/bot/exts/filtering/_filters/unique/discord_token.py index 6174ee30b..f4b9cc741 100644 --- a/bot/exts/filtering/_filters/unique/discord_token.py +++ b/bot/exts/filtering/_filters/unique/discord_token.py @@ -61,7 +61,7 @@ class DiscordTokenFilter(UniqueFilter): """Scans messages for potential discord client tokens and removes them.""" name = "discord_token" - events = (Event.MESSAGE, Event.MESSAGE_EDIT) + events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) extra_fields_type = ExtraDiscordTokenSettings @property diff --git a/bot/exts/filtering/_filters/unique/everyone.py b/bot/exts/filtering/_filters/unique/everyone.py index a32e67cc5..e49ede82f 100644 --- a/bot/exts/filtering/_filters/unique/everyone.py +++ b/bot/exts/filtering/_filters/unique/everyone.py @@ -16,7 +16,7 @@ class EveryoneFilter(UniqueFilter): """Filter messages which contain `@everyone` and `@here` tags outside a codeblock.""" name = "everyone" - events = (Event.MESSAGE, Event.MESSAGE_EDIT) + events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) async def triggered_on(self, ctx: FilterContext) -> bool: """Search for the filter's content within a given context.""" diff --git a/bot/exts/filtering/_filters/unique/webhook.py b/bot/exts/filtering/_filters/unique/webhook.py index 965ef42eb..4e1e2e44d 100644 --- a/bot/exts/filtering/_filters/unique/webhook.py +++ b/bot/exts/filtering/_filters/unique/webhook.py @@ -22,7 +22,7 @@ class WebhookFilter(UniqueFilter): """Scan messages to detect Discord webhooks links.""" name = "webhook" - events = (Event.MESSAGE, Event.MESSAGE_EDIT) + events = (Event.MESSAGE, Event.MESSAGE_EDIT, Event.SNEKBOX) @property def mod_log(self) -> ModLog | None: diff --git a/bot/exts/filtering/_settings_types/actions/remove_context.py b/bot/exts/filtering/_settings_types/actions/remove_context.py index 7ead88818..5ec2613f4 100644 --- a/bot/exts/filtering/_settings_types/actions/remove_context.py +++ b/bot/exts/filtering/_settings_types/actions/remove_context.py @@ -28,8 +28,8 @@ async def upload_messages_attachments(ctx: FilterContext, messages: list[Message return destination = messages[0].guild.get_channel(Channels.attachment_log) for message in messages: - if message.attachments and message.id not in ctx.attachments: - ctx.attachments[message.id] = await send_attachments(message, destination, link_large=False) + if message.attachments and message.id not in ctx.uploaded_attachments: + ctx.uploaded_attachments[message.id] = await send_attachments(message, destination, link_large=False) class RemoveContext(ActionEntry): diff --git a/bot/exts/filtering/_ui/ui.py b/bot/exts/filtering/_ui/ui.py index 157906d6b..8cd2864a9 100644 --- a/bot/exts/filtering/_ui/ui.py +++ b/bot/exts/filtering/_ui/ui.py @@ -59,10 +59,10 @@ async def _build_alert_message_content(ctx: FilterContext, current_message_lengt # For multiple messages and those with attachments or excessive newlines, use the logs API if ctx.messages_deletion and ctx.upload_deletion_logs and any(( ctx.related_messages, - len(ctx.attachments) > 0, + len(ctx.uploaded_attachments) > 0, ctx.content.count('\n') > 15 )): - url = await upload_log(ctx.related_messages, bot.instance.user.id, ctx.attachments) + url = await upload_log(ctx.related_messages, bot.instance.user.id, ctx.uploaded_attachments) return f"A complete log of the offending messages can be found [here]({url})" alert_content = escape_markdown(ctx.content) @@ -70,7 +70,7 @@ async def _build_alert_message_content(ctx: FilterContext, current_message_lengt if len(alert_content) > remaining_chars: if ctx.messages_deletion and ctx.upload_deletion_logs: - url = await upload_log([ctx.message], bot.instance.user.id, ctx.attachments) + url = await upload_log([ctx.message], bot.instance.user.id, ctx.uploaded_attachments) log_site_msg = f"The full message can be found [here]({url})" # 7 because that's the length of "[...]\n\n" return alert_content[:remaining_chars - (7 + len(log_site_msg))] + "[...]\n\n" + log_site_msg diff --git a/bot/exts/filtering/filtering.py b/bot/exts/filtering/filtering.py index c4417e5e0..2a7f8f81f 100644 --- a/bot/exts/filtering/filtering.py +++ b/bot/exts/filtering/filtering.py @@ -40,6 +40,7 @@ from bot.exts.filtering._ui.ui import ( ) from bot.exts.filtering._utils import past_tense, repr_equals, starting_value, to_serializable from bot.exts.moderation.infraction.infractions import COMP_BAN_DURATION, COMP_BAN_REASON +from bot.exts.utils.snekbox._io import FileAttachment from bot.log import get_logger from bot.pagination import LinePaginator from bot.utils.channel import is_mod_channel @@ -251,24 +252,30 @@ class Filtering(Cog): ctx = FilterContext(Event.NICKNAME, member, None, member.display_name, None) await self._check_bad_name(ctx) - async def filter_snekbox_output(self, snekbox_result: str, msg: Message) -> bool: + async def filter_snekbox_output( + self, stdout: str, files: list[FileAttachment], msg: Message + ) -> tuple[bool, set[str]]: """ Filter the result of a snekbox command to see if it violates any of our rules, and then respond accordingly. Also requires the original message, to check whether to filter and for alerting. Any action (deletion, infraction) will be applied in the context of the original message. - Returns whether a filter was triggered or not. + Returns whether the output should be blocked, as well as a list of blocked file extensions. """ - ctx = FilterContext.from_message(Event.MESSAGE, msg).replace(content=snekbox_result) + content = stdout + if files: # Filter the filenames as well. + content += "\n\n" + "\n".join(file.filename for file in files) + ctx = FilterContext.from_message(Event.SNEKBOX, msg).replace(content=content, attachments=files) + result_actions, list_messages, triggers = await self._resolve_action(ctx) if result_actions: await result_actions.action(ctx) if ctx.send_alert: await self._send_alert(ctx, list_messages) - self._increment_stats(triggers) - return result_actions is not None + self._increment_stats(triggers) + return result_actions is not None, ctx.blocked_exts # endregion # region: blacklist commands diff --git a/bot/exts/utils/snekbox/_cog.py b/bot/exts/utils/snekbox/_cog.py index 567fe6c24..d7e8bc93c 100644 --- a/bot/exts/utils/snekbox/_cog.py +++ b/bot/exts/utils/snekbox/_cog.py @@ -14,9 +14,8 @@ from pydis_core.utils import interactions from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from bot.bot import Bot -from bot.constants import Channels, Emojis, MODERATION_ROLES, Roles, STAFF_PARTNERS_COMMUNITY_ROLES, URLs +from bot.constants import Channels, Emojis, MODERATION_ROLES, Roles, URLs from bot.decorators import redirect_output -from bot.exts.events.code_jams._channels import CATEGORY_NAME as JAM_CATEGORY_NAME from bot.exts.filtering._filter_lists.extension import TXT_LIKE_FILES from bot.exts.help_channels._channel import is_help_forum_post from bot.exts.utils.snekbox._eval import EvalJob, EvalResult @@ -288,37 +287,22 @@ class Snekbox(Cog): return output, paste_link - def get_extensions_whitelist(self) -> set[str]: - """Return a set of whitelisted file extensions.""" - return set(self.bot.filter_list_cache['FILE_FORMAT.True'].keys()) | TXT_LIKE_FILES - - def _filter_files(self, ctx: Context, files: list[FileAttachment]) -> FilteredFiles: + def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str]) -> FilteredFiles: """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" - # Check if user is staff, if is, return - # Since we only care that roles exist to iterate over, check for the attr rather than a User/Member instance - if hasattr(ctx.author, "roles") and any(role.id in STAFF_PARTNERS_COMMUNITY_ROLES for role in ctx.author.roles): - return FilteredFiles(files, []) - # Ignore code jam channels - if getattr(ctx.channel, "category", None) and ctx.channel.category.name == JAM_CATEGORY_NAME: - return FilteredFiles(files, []) - - # Get whitelisted extensions - whitelist = self.get_extensions_whitelist() - # Filter files into allowed and blocked blocked = [] allowed = [] for file in files: - if file.suffix in whitelist: - allowed.append(file) - else: + if file.suffix in blocked_exts: blocked.append(file) + else: + allowed.append(file) if blocked: blocked_str = ", ".join(f.suffix for f in blocked) log.info( f"User '{ctx.author}' ({ctx.author.id}) uploaded blacklisted file(s) in eval: {blocked_str}", - extra={"attachment_list": [f.path for f in files]} + extra={"attachment_list": [f.filename for f in files]} ) return FilteredFiles(allowed, blocked) @@ -365,31 +349,8 @@ class Snekbox(Cog): else: self.bot.stats.incr("snekbox.python.success") - # Filter file extensions - allowed, blocked = self._filter_files(ctx, result.files) - # Also scan failed files for blocked extensions - failed_files = [FileAttachment(name, b"") for name in result.failed_files] - blocked.extend(self._filter_files(ctx, failed_files).blocked) - # Add notice if any files were blocked - if blocked: - blocked_sorted = sorted(set(f.suffix for f in blocked)) - # Only no extension - if len(blocked_sorted) == 1 and blocked_sorted[0] == "": - blocked_msg = "Files with no extension can't be uploaded." - # Both - elif "" in blocked_sorted: - blocked_str = ", ".join(ext for ext in blocked_sorted if ext) - blocked_msg = ( - f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" - ) - else: - blocked_str = ", ".join(blocked_sorted) - blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" - - msg += f"\n{Emojis.failed_file} {blocked_msg}" - # Split text files - text_files = [f for f in allowed if f.suffix in TXT_LIKE_FILES] + text_files = [f for f in result.files if f.suffix in TXT_LIKE_FILES] # Inline until budget, then upload to paste service # Budget is shared with stdout, so subtract what we've already used budget_lines = MAX_OUTPUT_BLOCK_LINES - (output.count("\n") + 1) @@ -417,8 +378,35 @@ class Snekbox(Cog): budget_chars -= len(file_text) filter_cog: Filtering | None = self.bot.get_cog("Filtering") - if filter_cog and (await filter_cog.filter_snekbox_output(msg, ctx.message)): - return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + blocked_exts = set() + # Include failed files in the scan. + failed_files = [FileAttachment(name, b"") for name in result.failed_files] + total_files = result.files + failed_files + if filter_cog: + block_output, blocked_exts = await filter_cog.filter_snekbox_output(msg, total_files, ctx.message) + if block_output: + return await ctx.send("Attempt to circumvent filter detected. Moderator team has been alerted.") + + # Filter file extensions + allowed, blocked = self._filter_files(ctx, result.files, blocked_exts) + blocked.extend(self._filter_files(ctx, failed_files, blocked_exts).blocked) + # Add notice if any files were blocked + if blocked: + blocked_sorted = sorted(set(f.suffix for f in blocked)) + # Only no extension + if len(blocked_sorted) == 1 and blocked_sorted[0] == "": + blocked_msg = "Files with no extension can't be uploaded." + # Both + elif "" in blocked_sorted: + blocked_str = ", ".join(ext for ext in blocked_sorted if ext) + blocked_msg = ( + f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + ) + else: + blocked_str = ", ".join(blocked_sorted) + blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" + + msg += f"\n{Emojis.failed_file} {blocked_msg}" # Upload remaining non-text files files = [f.to_file() for f in allowed if f not in text_files] diff --git a/bot/exts/utils/snekbox/_io.py b/bot/exts/utils/snekbox/_io.py index 9be396335..a45ecec1a 100644 --- a/bot/exts/utils/snekbox/_io.py +++ b/bot/exts/utils/snekbox/_io.py @@ -53,23 +53,23 @@ def normalize_discord_file_name(name: str) -> str: class FileAttachment: """File Attachment from Snekbox eval.""" - path: str + filename: str content: bytes def __repr__(self) -> str: """Return the content as a string.""" content = f"{self.content[:10]}..." if len(self.content) > 10 else self.content - return f"FileAttachment(path={self.path!r}, content={content})" + return f"FileAttachment(path={self.filename!r}, content={content})" @property def suffix(self) -> str: """Return the file suffix.""" - return PurePosixPath(self.path).suffix + return PurePosixPath(self.filename).suffix @property def name(self) -> str: """Return the file name.""" - return PurePosixPath(self.path).name + return PurePosixPath(self.filename).name @classmethod def from_dict(cls, data: dict, size_limit: int = FILE_SIZE_LIMIT) -> FileAttachment: @@ -92,7 +92,7 @@ class FileAttachment: content = content.encode("utf-8") return { - "path": self.path, + "path": self.filename, "content": b64encode(content).decode("ascii"), } diff --git a/tests/bot/exts/filtering/test_extension_filter.py b/tests/bot/exts/filtering/test_extension_filter.py index 0ad41116d..351daa0b4 100644 --- a/tests/bot/exts/filtering/test_extension_filter.py +++ b/tests/bot/exts/filtering/test_extension_filter.py @@ -45,9 +45,9 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_allowed_attachment(self): """Messages with allowed extensions should trigger the whitelist and result in no actions or messages.""" attachment = MockAttachment(filename="python.first") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - result = await self.filter_list.actions_for(self.ctx) + result = await self.filter_list.actions_for(ctx) self.assertEqual(result, (None, [], {ListType.ALLOW: [self.filter_list[ListType.ALLOW].filters[1]]})) @@ -62,9 +62,9 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_message_with_illegal_extension(self): """A message with an illegal extension shouldn't trigger the whitelist, and return some action and message.""" attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - result = await self.filter_list.actions_for(self.ctx) + result = await self.filter_list.actions_for(ctx) self.assertEqual(result, ({}, ["`.disallowed`"], {ListType.ALLOW: []})) @@ -72,11 +72,11 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_python_file_redirect_embed_description(self): """A message containing a .py file should result in an embed redirecting the user to our paste site.""" attachment = MockAttachment(filename="python.py") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) - self.assertEqual(self.ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) + self.assertEqual(ctx.dm_embed, extension.PY_EMBED_DESCRIPTION) @patch("bot.instance", BOT) async def test_txt_file_redirect_embed_description(self): @@ -91,12 +91,12 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): with self.subTest(file_name=file_name, disallowed_extension=disallowed_extension): attachment = MockAttachment(filename=f"{file_name}{disallowed_extension}") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) self.assertEqual( - self.ctx.dm_embed, + ctx.dm_embed, extension.TXT_EMBED_DESCRIPTION.format( blocked_extension=disallowed_extension, ) @@ -106,13 +106,13 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): async def test_other_disallowed_extension_embed_description(self): """Test the description for a non .py/.txt/.json/.csv disallowed extension.""" attachment = MockAttachment(filename="python.disallowed") - self.message.attachments = [attachment] + ctx = self.ctx.replace(attachments=[attachment]) - await self.filter_list.actions_for(self.ctx) + await self.filter_list.actions_for(ctx) meta_channel = BOT.get_channel(Channels.meta) self.assertEqual( - self.ctx.dm_embed, + ctx.dm_embed, extension.DISALLOWED_EMBED_DESCRIPTION.format( joined_whitelist=", ".join(self.whitelist), blocked_extensions_str=".disallowed", @@ -134,6 +134,6 @@ class ExtensionsListTests(unittest.IsolatedAsyncioTestCase): for extensions, expected_disallowed_extensions in test_values: with self.subTest(extensions=extensions, expected_disallowed_extensions=expected_disallowed_extensions): - self.message.attachments = [MockAttachment(filename=f"filename{ext}") for ext in extensions] - result = await self.filter_list.actions_for(self.ctx) + ctx = self.ctx.replace(attachments=[MockAttachment(filename=f"filename{ext}") for ext in extensions]) + result = await self.filter_list.actions_for(ctx) self.assertCountEqual(result[1], expected_disallowed_extensions) diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9dcf7fd8c..79ac8ea2c 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -307,7 +307,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # Should not be called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code('MyAwesomeCode') @@ -339,7 +339,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.format_output = AsyncMock(return_value=('Way too long beard', 'lookatmybeard.com')) mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -368,7 +368,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") @@ -396,7 +396,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase): self.cog.upload_output = AsyncMock() # This function isn't called mocked_filter_cog = MagicMock() - mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=False) + mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"])) self.bot.get_cog.return_value = mocked_filter_cog job = EvalJob.from_code("MyAwesomeCode").as_version("3.11") |