From eacf52f2cef2518891448fa3c9541967836dd54c Mon Sep 17 00:00:00 2001 From: fatelei Date: Thu, 12 Mar 2026 10:51:02 +0800 Subject: [PATCH] feat: support files as context --- api/dify_graph/nodes/llm/node.py | 11 ++++ api/dify_graph/nodes/tool/tool_node.py | 8 +++ .../core/workflow/nodes/llm/test_node.py | 62 +++++++++++++++++++ 3 files changed, 81 insertions(+) diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index b88ff404c0bc3a..196d8fdc1d88fb 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -29,6 +29,7 @@ WorkflowNodeExecutionStatus, ) from dify_graph.file import File, FileTransferMethod, FileType, file_manager +from dify_graph.file.constants import maybe_file_object from dify_graph.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, @@ -76,6 +77,7 @@ StringSegment, ) from extensions.ext_database import db +from factories.file_factory import build_from_mapping from models.dataset import SegmentAttachmentBinding from models.model import UploadFile @@ -679,10 +681,19 @@ def _fetch_context(self, node_data: LLMNodeData): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] context_files: list[File] = [] + tenant_id = self.require_dify_context().tenant_id for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" else: + if isinstance(item, File): + context_files.append(item) + continue + if maybe_file_object(item): + file_obj = build_from_mapping(mapping=item, tenant_id=tenant_id, config=None) + context_files.append(file_obj) + continue + if "content" not in item: raise InvalidContextStructureError(f"Invalid context structure: {item}") diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 44d0ca885d61d6..f2b998913897fb 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any @@ -37,6 +38,9 @@ from dify_graph.runtime import GraphRuntimeState, VariablePool +logger = logging.getLogger(__name__) + + class ToolNode(Node[ToolNodeData]): """ Tool Node @@ -100,6 +104,7 @@ def _run(self) -> Generator[NodeEventBase, None, None]: variable_pool, ) except ToolNodeError as e: + logger.warning(e, exc_info=True) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -138,6 +143,7 @@ def _run(self) -> Generator[NodeEventBase, None, None]: conversation_id=conversation_id.text if conversation_id else None, ) except ToolNodeError as e: + logger.warning(e, exc_info=True) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -161,6 +167,7 @@ def _run(self) -> Generator[NodeEventBase, None, None]: tool_runtime=tool_runtime, ) except ToolInvokeError as e: + logger.warning(e, exc_info=True) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -171,6 +178,7 @@ def _run(self) -> Generator[NodeEventBase, None, None]: ) ) except PluginInvokeError as e: + logger.warning(e, exc_info=True) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index d56035b6bc5a7d..edaddf2b168aa0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -786,6 +786,68 @@ def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypat ) assert mock_saved_file in llm_node._file_outputs + +@pytest.mark.xfail(reason="_fetch_context should accept File objects in context arrays and treat them as context_files") +def test_fetch_context_accepts_file_objects_as_context_files(llm_node: LLMNode, graph_runtime_state: GraphRuntimeState): + # Enable context and point to a custom variable + llm_node.node_data.context.enabled = True + llm_node.node_data.context.variable_selector = ("nodeA", "ctx") + + # Build a File object and store it inside an Array segment via VariablePool.add + file_obj = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id=str(uuid.uuid4()), + filename="img.png", + extension=".png", + mime_type="image/png", + size=10, + ) + graph_runtime_state.variable_pool.add(("nodeA", "ctx"), [file_obj]) + + # Run _fetch_context generator and collect event + gen = llm_node._fetch_context(llm_node.node_data) + events = list(gen) + + # Expect a single RunRetrieverResourceEvent with our file as context_files + assert len(events) == 1 + event = events[0] + assert hasattr(event, "context_files") + assert event.context_files + assert event.context_files[0].filename == "img.png" + + +@pytest.mark.xfail(reason="_fetch_context should accept file mapping dicts (dify_model_identity) as context_files") +def test_fetch_context_accepts_file_dicts_as_context_files(llm_node: LLMNode, graph_runtime_state: GraphRuntimeState): + from dify_graph.file.constants import FILE_MODEL_IDENTITY + + llm_node.node_data.context.enabled = True + llm_node.node_data.context.variable_selector = ("nodeB", "ctx") + + file_mapping = { + "dify_model_identity": FILE_MODEL_IDENTITY, + "tenant_id": "1", + "type": FileType.IMAGE.value, + "transfer_method": FileTransferMethod.LOCAL_FILE.value, + "related_id": str(uuid.uuid4()), + "filename": "pic.jpeg", + "extension": ".jpeg", + "mime_type": "image/jpeg", + "size": 12, + } + graph_runtime_state.variable_pool.add(("nodeB", "ctx"), [file_mapping]) + + gen = llm_node._fetch_context(llm_node.node_data) + events = list(gen) + + assert len(events) == 1 + event = events[0] + assert hasattr(event, "context_files") + assert event.context_files + assert event.context_files[0].filename == "pic.jpeg" + def test_unknown_content_type(self, llm_node_for_multimodal): llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(