Skip to content

Commit bfd4d70

Browse files
committed
fix: preserve context during child graph init
1 parent 358cf8b commit bfd4d70

3 files changed

Lines changed: 104 additions & 8 deletions

File tree

api/core/workflow/workflow_entry.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,14 @@ def build_child_engine(
112112
if has_root_node is False:
113113
raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found")
114114

115-
child_graph = Graph.init(
116-
graph_config=graph_config,
117-
node_factory=node_factory,
118-
root_node_id=root_node_id,
119-
)
115+
# Graph.init creates node instances immediately, and Dify node creation can
116+
# read Flask-bound services such as model provider configuration.
117+
with child_graph_runtime_state.execution_context:
118+
child_graph = Graph.init(
119+
graph_config=graph_config,
120+
node_factory=node_factory,
121+
root_node_id=root_node_id,
122+
)
120123

121124
command_channel = InMemoryChannel()
122125
config = GraphEngineConfig()

api/tests/unit_tests/core/workflow/test_workflow_entry.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from types import SimpleNamespace
1+
from contextlib import AbstractContextManager
2+
from types import SimpleNamespace, TracebackType
23

34
import pytest
45

56
from configs import dify_config
67
from core.helper.code_executor.code_executor import CodeLanguage
8+
from core.workflow import workflow_entry
79
from core.workflow.system_variables import build_system_variables, default_system_variables
810
from core.workflow.variable_prefixes import (
911
CONVERSATION_VARIABLE_NODE_ID,
@@ -14,8 +16,9 @@
1416
from graphon.file import File, FileTransferMethod, FileType
1517
from graphon.nodes.code.code_node import CodeNode
1618
from graphon.nodes.code.limits import CodeNodeLimits
17-
from graphon.runtime import VariablePool
19+
from graphon.runtime import GraphRuntimeState, VariablePool
1820
from graphon.variables.variables import StringVariable
21+
from tests.workflow_test_utils import build_test_graph_init_params
1922

2023

2124
@pytest.fixture(autouse=True)
@@ -52,6 +55,96 @@ def fake_head(method, url, *args, **kwargs):
5255
class TestWorkflowEntry:
5356
"""Test WorkflowEntry class methods."""
5457

58+
def test_child_engine_enters_execution_context_while_initializing_graph(self, monkeypatch: pytest.MonkeyPatch):
59+
"""Child graph node factories should run inside the parent execution context."""
60+
61+
class RecordingExecutionContext(AbstractContextManager[None]):
62+
entered: bool
63+
was_entered_during_graph_init: bool
64+
65+
def __init__(self) -> None:
66+
self.entered = False
67+
self.was_entered_during_graph_init = False
68+
69+
def __enter__(self) -> None:
70+
self.entered = True
71+
72+
def __exit__(
73+
self,
74+
exc_type: type[BaseException] | None,
75+
exc_value: BaseException | None,
76+
traceback: TracebackType | None,
77+
) -> bool:
78+
self.entered = False
79+
return False
80+
81+
class StubDifyNodeFactory:
82+
graph_runtime_state: GraphRuntimeState
83+
84+
def __init__(self, *, graph_init_params: object, graph_runtime_state: GraphRuntimeState) -> None:
85+
self.graph_init_params = graph_init_params
86+
self.graph_runtime_state = graph_runtime_state
87+
created_runtime_states.append(graph_runtime_state)
88+
89+
class StubGraphEngine:
90+
graph_runtime_state: GraphRuntimeState
91+
layers: list[object]
92+
93+
def __init__(
94+
self,
95+
*,
96+
workflow_id: str,
97+
graph: object,
98+
graph_runtime_state: GraphRuntimeState,
99+
command_channel: object,
100+
config: object,
101+
child_engine_builder: object,
102+
) -> None:
103+
self.workflow_id = workflow_id
104+
self.graph = graph
105+
self.graph_runtime_state = graph_runtime_state
106+
self.command_channel = command_channel
107+
self.config = config
108+
self.child_engine_builder = child_engine_builder
109+
self.layers = []
110+
111+
def layer(self, layer: object) -> None:
112+
self.layers.append(layer)
113+
114+
created_runtime_states: list[GraphRuntimeState] = []
115+
execution_context = RecordingExecutionContext()
116+
parent_runtime_state = GraphRuntimeState(
117+
variable_pool=VariablePool(),
118+
start_at=0.0,
119+
execution_context=execution_context,
120+
)
121+
graph_init_params = build_test_graph_init_params(
122+
graph_config={"nodes": [{"id": "root"}], "edges": []},
123+
)
124+
125+
def init_graph(*, graph_config: object, node_factory: object, root_node_id: str) -> object:
126+
execution_context.was_entered_during_graph_init = execution_context.entered
127+
return {"graph_config": graph_config, "node_factory": node_factory, "root_node_id": root_node_id}
128+
129+
monkeypatch.setattr(workflow_entry, "DifyNodeFactory", StubDifyNodeFactory)
130+
monkeypatch.setattr(workflow_entry.Graph, "init", staticmethod(init_graph))
131+
monkeypatch.setattr(workflow_entry, "GraphEngine", StubGraphEngine)
132+
monkeypatch.setattr(workflow_entry, "LLMQuotaLayer", lambda tenant_id: ("quota", tenant_id))
133+
134+
engine = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-1").build_child_engine(
135+
workflow_id="workflow-1",
136+
graph_init_params=graph_init_params,
137+
parent_graph_runtime_state=parent_runtime_state,
138+
root_node_id="root",
139+
)
140+
141+
assert isinstance(engine, StubGraphEngine)
142+
assert execution_context.was_entered_during_graph_init is True
143+
assert execution_context.entered is False
144+
assert created_runtime_states[0].execution_context is execution_context
145+
assert engine.graph_runtime_state.execution_context is execution_context
146+
assert engine.layers == [("quota", "tenant-1")]
147+
55148
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
56149
"""Test mapping system variables from user inputs to variable pool."""
57150
# Initialize variable pool with system variables

api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
125125
variable_pool=sentinel.parent_variable_pool,
126126
)
127127
child_graph = sentinel.child_graph
128-
child_graph_runtime_state = sentinel.child_graph_runtime_state
128+
child_graph_runtime_state = SimpleNamespace(execution_context=nullcontext(None))
129129
child_engine = MagicMock()
130130

131131
with (

0 commit comments

Comments
 (0)