Skip to content

Commit 92cf6d1

Browse files
committed
fix: preserve context during child graph init
1 parent 358cf8b commit 92cf6d1

3 files changed

Lines changed: 107 additions & 8 deletions

File tree

api/core/workflow/workflow_entry.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,14 @@ def build_child_engine(
112112
if has_root_node is False:
113113
raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found")
114114

115-
child_graph = Graph.init(
116-
graph_config=graph_config,
117-
node_factory=node_factory,
118-
root_node_id=root_node_id,
119-
)
115+
# Graph.init creates node instances immediately, and Dify node creation can
116+
# read Flask-bound services such as model provider configuration.
117+
with child_graph_runtime_state.execution_context:
118+
child_graph = Graph.init(
119+
graph_config=graph_config,
120+
node_factory=node_factory,
121+
root_node_id=root_node_id,
122+
)
120123

121124
command_channel = InMemoryChannel()
122125
config = GraphEngineConfig()

api/tests/unit_tests/core/workflow/test_workflow_entry.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
from types import SimpleNamespace
1+
from contextlib import AbstractContextManager
2+
from types import SimpleNamespace, TracebackType
3+
from typing import override
24

35
import pytest
46

57
from configs import dify_config
68
from core.helper.code_executor.code_executor import CodeLanguage
9+
from core.workflow import workflow_entry
710
from core.workflow.system_variables import build_system_variables, default_system_variables
811
from core.workflow.variable_prefixes import (
912
CONVERSATION_VARIABLE_NODE_ID,
@@ -14,8 +17,9 @@
1417
from graphon.file import File, FileTransferMethod, FileType
1518
from graphon.nodes.code.code_node import CodeNode
1619
from graphon.nodes.code.limits import CodeNodeLimits
17-
from graphon.runtime import VariablePool
20+
from graphon.runtime import GraphRuntimeState, VariablePool
1821
from graphon.variables.variables import StringVariable
22+
from tests.workflow_test_utils import build_test_graph_init_params
1923

2024

2125
@pytest.fixture(autouse=True)
@@ -52,6 +56,98 @@ def fake_head(method, url, *args, **kwargs):
5256
class TestWorkflowEntry:
5357
"""Test WorkflowEntry class methods."""
5458

59+
def test_child_engine_enters_execution_context_while_initializing_graph(self, monkeypatch: pytest.MonkeyPatch):
60+
"""Child graph node factories should run inside the parent execution context."""
61+
62+
class RecordingExecutionContext(AbstractContextManager[None]):
63+
entered: bool
64+
was_entered_during_graph_init: bool
65+
66+
def __init__(self) -> None:
67+
self.entered = False
68+
self.was_entered_during_graph_init = False
69+
70+
@override
71+
def __enter__(self) -> None:
72+
self.entered = True
73+
74+
@override
75+
def __exit__(
76+
self,
77+
exc_type: type[BaseException] | None,
78+
exc_value: BaseException | None,
79+
traceback: TracebackType | None,
80+
) -> bool:
81+
self.entered = False
82+
return False
83+
84+
class StubDifyNodeFactory:
85+
graph_runtime_state: GraphRuntimeState
86+
87+
def __init__(self, *, graph_init_params: object, graph_runtime_state: GraphRuntimeState) -> None:
88+
self.graph_init_params = graph_init_params
89+
self.graph_runtime_state = graph_runtime_state
90+
created_runtime_states.append(graph_runtime_state)
91+
92+
class StubGraphEngine:
93+
graph_runtime_state: GraphRuntimeState
94+
layers: list[object]
95+
96+
def __init__(
97+
self,
98+
*,
99+
workflow_id: str,
100+
graph: object,
101+
graph_runtime_state: GraphRuntimeState,
102+
command_channel: object,
103+
config: object,
104+
child_engine_builder: object,
105+
) -> None:
106+
self.workflow_id = workflow_id
107+
self.graph = graph
108+
self.graph_runtime_state = graph_runtime_state
109+
self.command_channel = command_channel
110+
self.config = config
111+
self.child_engine_builder = child_engine_builder
112+
self.layers = []
113+
114+
def layer(self, layer: object) -> None:
115+
self.layers.append(layer)
116+
117+
created_runtime_states: list[GraphRuntimeState] = []
118+
execution_context = RecordingExecutionContext()
119+
parent_runtime_state = GraphRuntimeState(
120+
variable_pool=VariablePool(),
121+
start_at=0.0,
122+
execution_context=execution_context,
123+
)
124+
graph_init_params = build_test_graph_init_params(
125+
graph_config={"nodes": [{"id": "root"}], "edges": []},
126+
)
127+
128+
def init_graph(*, graph_config: object, node_factory: object, root_node_id: str) -> object:
129+
execution_context.was_entered_during_graph_init = execution_context.entered
130+
return {"graph_config": graph_config, "node_factory": node_factory, "root_node_id": root_node_id}
131+
132+
monkeypatch.setattr(workflow_entry, "DifyNodeFactory", StubDifyNodeFactory)
133+
monkeypatch.setattr(workflow_entry.Graph, "init", staticmethod(init_graph))
134+
monkeypatch.setattr(workflow_entry, "GraphEngine", StubGraphEngine)
135+
monkeypatch.setattr(workflow_entry, "LLMQuotaLayer", lambda tenant_id: ("quota", tenant_id))
136+
137+
engine = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-1").build_child_engine(
138+
workflow_id="workflow-1",
139+
graph_init_params=graph_init_params,
140+
parent_graph_runtime_state=parent_runtime_state,
141+
root_node_id="root",
142+
)
143+
144+
assert isinstance(engine, StubGraphEngine)
145+
assert execution_context.was_entered_during_graph_init is True
146+
assert execution_context.entered is False
147+
assert created_runtime_states[0].execution_context is execution_context
148+
assert engine.graph_runtime_state.execution_context is execution_context
149+
assert engine.layers == [("quota", "tenant-1")]
150+
55151
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
56152
"""Test mapping system variables from user inputs to variable pool."""
57153
# Initialize variable pool with system variables

api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
125125
variable_pool=sentinel.parent_variable_pool,
126126
)
127127
child_graph = sentinel.child_graph
128-
child_graph_runtime_state = sentinel.child_graph_runtime_state
128+
child_graph_runtime_state = SimpleNamespace(execution_context=nullcontext(None))
129129
child_engine = MagicMock()
130130

131131
with (

0 commit comments

Comments
 (0)