Skip to content

Commit c96b72b

Browse files
justinchubyCopilot
andauthored
Fix sort() on subgraphs (#347)
Graph.sort() didn't guard against predecessor nodes from outer-scope graphs. The fix adds a predecessor not in node_depth check to skip cross-graph predecessors in add_predecessor(). Added unit tests. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent c1a0c23 commit c96b72b

File tree

2 files changed

+385
-0
lines changed

2 files changed

+385
-0
lines changed

src/onnx_ir/_core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3547,6 +3547,9 @@ def add_predecessor(child: Node, predecessor: Node | None) -> None:
35473547
"""Add a predecessor of a node, and increment the depth of the predecessor."""
35483548
if predecessor is None:
35493549
return
3550+
if predecessor not in node_depth:
3551+
# Predecessor is from a different graph (e.g., outer scope); skip it.
3552+
return
35503553
node_predecessors[child].append(predecessor)
35513554
node_depth[predecessor] += 1
35523555

src/onnx_ir/_core_test.py

Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5074,5 +5074,387 @@ def test_clone_graph_view_with_intermediate_values(self):
50745074
self.assertEqual(full_graph[2].name, "mul")
50755075

50765076

5077+
def _sort_op_names(graph: ir.Graph) -> list[str]:
5078+
"""Return the list of op_type names in graph order."""
5079+
return [node.op_type for node in graph]
5080+
5081+
5082+
def _build_linear_graph(op_types: list[str], graph_name: str = "test") -> ir.Graph:
5083+
"""Build a linear chain: X -> op1 -> op2 -> ... -> Y."""
5084+
x = ir.Value(name="X")
5085+
prev = x
5086+
nodes = []
5087+
for i, op_type in enumerate(op_types):
5088+
node = ir.Node("", op_type, inputs=[prev], name=f"n{i}")
5089+
prev = node.outputs[0]
5090+
prev.name = f"v{i}"
5091+
nodes.append(node)
5092+
5093+
graph = ir.Graph(
5094+
inputs=[x],
5095+
outputs=[prev],
5096+
nodes=nodes,
5097+
name=graph_name,
5098+
opset_imports={"": 21},
5099+
)
5100+
return graph
5101+
5102+
5103+
class GraphSortTest(unittest.TestCase):
5104+
"""Tests for Graph.sort() topological sorting."""
5105+
5106+
def test_sort_empty_graph(self):
5107+
x = ir.Value(name="X")
5108+
graph = ir.Graph(
5109+
inputs=[x], outputs=[x], nodes=[], name="empty", opset_imports={"": 21}
5110+
)
5111+
graph.sort()
5112+
self.assertEqual(list(graph), [])
5113+
5114+
def test_sort_single_node(self):
5115+
graph = _build_linear_graph(["Relu"])
5116+
graph.sort()
5117+
self.assertEqual(_sort_op_names(graph), ["Relu"])
5118+
5119+
def test_sort_already_sorted_linear(self):
5120+
graph = _build_linear_graph(["Relu", "Sigmoid", "Tanh"])
5121+
graph.sort()
5122+
self.assertEqual(_sort_op_names(graph), ["Relu", "Sigmoid", "Tanh"])
5123+
5124+
def test_sort_reversed_linear(self):
5125+
"""Nodes in reverse order should be sorted to topological order."""
5126+
x = ir.Value(name="X")
5127+
n0 = ir.Node("", "Relu", inputs=[x], name="n0")
5128+
v0 = n0.outputs[0]
5129+
v0.name = "v0"
5130+
n1 = ir.Node("", "Sigmoid", inputs=[v0], name="n1")
5131+
v1 = n1.outputs[0]
5132+
v1.name = "v1"
5133+
n2 = ir.Node("", "Tanh", inputs=[v1], name="n2")
5134+
v2 = n2.outputs[0]
5135+
v2.name = "v2"
5136+
5137+
# Insert in reverse order
5138+
graph = ir.Graph(
5139+
inputs=[x],
5140+
outputs=[v2],
5141+
nodes=[n2, n1, n0],
5142+
name="test",
5143+
opset_imports={"": 21},
5144+
)
5145+
graph.sort()
5146+
self.assertEqual(_sort_op_names(graph), ["Relu", "Sigmoid", "Tanh"])
5147+
5148+
def test_sort_diamond_graph(self):
5149+
"""Diamond pattern: X -> A -> C, X -> B -> C."""
5150+
x = ir.Value(name="X")
5151+
a = ir.Node("", "Relu", inputs=[x], name="A")
5152+
va = a.outputs[0]
5153+
va.name = "va"
5154+
b = ir.Node("", "Sigmoid", inputs=[x], name="B")
5155+
vb = b.outputs[0]
5156+
vb.name = "vb"
5157+
c = ir.Node("", "Add", inputs=[va, vb], name="C")
5158+
vc = c.outputs[0]
5159+
vc.name = "vc"
5160+
5161+
# Insert out of order: C before A and B
5162+
graph = ir.Graph(
5163+
inputs=[x],
5164+
outputs=[vc],
5165+
nodes=[c, b, a],
5166+
name="test",
5167+
opset_imports={"": 21},
5168+
)
5169+
graph.sort()
5170+
ops = _sort_op_names(graph)
5171+
# C must come after both A and B
5172+
self.assertLess(ops.index("Relu"), ops.index("Add"))
5173+
self.assertLess(ops.index("Sigmoid"), ops.index("Add"))
5174+
5175+
def test_sort_with_none_inputs(self):
5176+
"""Nodes with None inputs should not cause errors."""
5177+
x = ir.Value(name="X")
5178+
node = ir.Node("", "Relu", inputs=[None, x], name="n0")
5179+
out = node.outputs[0]
5180+
out.name = "Y"
5181+
5182+
graph = ir.Graph(
5183+
inputs=[x],
5184+
outputs=[out],
5185+
nodes=[node],
5186+
name="test",
5187+
opset_imports={"": 21},
5188+
)
5189+
graph.sort()
5190+
self.assertEqual(_sort_op_names(graph), ["Relu"])
5191+
5192+
def test_sort_preserves_order_when_possible(self):
5193+
"""Sort is stable: independent nodes preserve original order."""
5194+
x = ir.Value(name="X")
5195+
a = ir.Node("", "Relu", inputs=[x], name="A")
5196+
va = a.outputs[0]
5197+
va.name = "va"
5198+
b = ir.Node("", "Sigmoid", inputs=[x], name="B")
5199+
vb = b.outputs[0]
5200+
vb.name = "vb"
5201+
c = ir.Node("", "Tanh", inputs=[x], name="C")
5202+
vc = c.outputs[0]
5203+
vc.name = "vc"
5204+
merge = ir.Node("", "Sum", inputs=[va, vb, vc], name="Merge", num_outputs=1)
5205+
out = merge.outputs[0]
5206+
out.name = "Y"
5207+
5208+
graph = ir.Graph(
5209+
inputs=[x],
5210+
outputs=[out],
5211+
nodes=[a, b, c, merge],
5212+
name="test",
5213+
opset_imports={"": 21},
5214+
)
5215+
graph.sort()
5216+
# Independent nodes A, B, C should keep their original order
5217+
self.assertEqual(_sort_op_names(graph), ["Relu", "Sigmoid", "Tanh", "Sum"])
5218+
5219+
def _make_if_graph(self, unsorted_subgraph: bool = False) -> ir.Graph:
5220+
"""Create a graph with an If node containing a subgraph."""
5221+
x = ir.Value(name="X")
5222+
cond_node = ir.Node("", "Relu", inputs=[x], name="cond")
5223+
cond_val = cond_node.outputs[0]
5224+
cond_val.name = "cond_val"
5225+
5226+
sub_in = ir.Value(name="sub_in")
5227+
sub_sig = ir.Node("", "Sigmoid", inputs=[sub_in], name="sub_sig")
5228+
sub_v = sub_sig.outputs[0]
5229+
sub_v.name = "sub_v"
5230+
sub_tanh = ir.Node("", "Tanh", inputs=[sub_v], name="sub_tanh")
5231+
sub_out = sub_tanh.outputs[0]
5232+
sub_out.name = "sub_out"
5233+
5234+
sub_nodes = [sub_tanh, sub_sig] if unsorted_subgraph else [sub_sig, sub_tanh]
5235+
subgraph = ir.Graph(
5236+
inputs=[sub_in],
5237+
outputs=[sub_out],
5238+
nodes=sub_nodes,
5239+
name="then_branch",
5240+
opset_imports={"": 21},
5241+
)
5242+
5243+
then_attr = ir.Attr("then_branch", ir.AttributeType.GRAPH, subgraph)
5244+
if_node = ir.Node("", "If", [cond_val], [then_attr], name="if_node")
5245+
if_out = if_node.outputs[0]
5246+
if_out.name = "Y"
5247+
5248+
graph = ir.Graph(
5249+
inputs=[x],
5250+
outputs=[if_out],
5251+
nodes=[if_node, cond_node], # Unsorted: if before cond
5252+
name="main",
5253+
opset_imports={"": 21},
5254+
)
5255+
return graph
5256+
5257+
def test_sort_with_graph_attribute(self):
5258+
"""Sort handles subgraphs in GRAPH attributes."""
5259+
graph = self._make_if_graph(unsorted_subgraph=True)
5260+
graph.sort()
5261+
5262+
self.assertEqual(_sort_op_names(graph), ["Relu", "If"])
5263+
5264+
if_node = list(graph)[1]
5265+
subgraph = if_node.attributes["then_branch"].value
5266+
self.assertEqual(_sort_op_names(subgraph), ["Sigmoid", "Tanh"])
5267+
5268+
def test_sort_with_graphs_attribute(self):
5269+
"""Sort handles subgraphs in GRAPHS attributes (multiple graphs)."""
5270+
x = ir.Value(name="X")
5271+
relu = ir.Node("", "Relu", inputs=[x], name="relu")
5272+
relu_out = relu.outputs[0]
5273+
relu_out.name = "relu_out"
5274+
5275+
sub_in1 = ir.Value(name="sub_in1")
5276+
sub_node1 = ir.Node("", "Sigmoid", inputs=[sub_in1], name="sub1")
5277+
sub_out1 = sub_node1.outputs[0]
5278+
sub_out1.name = "sub_out1"
5279+
sg1 = ir.Graph(
5280+
inputs=[sub_in1],
5281+
outputs=[sub_out1],
5282+
nodes=[sub_node1],
5283+
name="sg1",
5284+
opset_imports={"": 21},
5285+
)
5286+
5287+
sub_in2 = ir.Value(name="sub_in2")
5288+
sub_node2 = ir.Node("", "Tanh", inputs=[sub_in2], name="sub2")
5289+
sub_out2 = sub_node2.outputs[0]
5290+
sub_out2.name = "sub_out2"
5291+
sg2 = ir.Graph(
5292+
inputs=[sub_in2],
5293+
outputs=[sub_out2],
5294+
nodes=[sub_node2],
5295+
name="sg2",
5296+
opset_imports={"": 21},
5297+
)
5298+
5299+
graphs_attr = ir.Attr("branches", ir.AttributeType.GRAPHS, [sg1, sg2])
5300+
multi_node = ir.Node(
5301+
"",
5302+
"CustomMultiBranch",
5303+
[relu_out],
5304+
[graphs_attr],
5305+
name="multi",
5306+
)
5307+
multi_out = multi_node.outputs[0]
5308+
multi_out.name = "Y"
5309+
5310+
graph = ir.Graph(
5311+
inputs=[x],
5312+
outputs=[multi_out],
5313+
nodes=[multi_node, relu], # Unsorted
5314+
name="main",
5315+
opset_imports={"": 21},
5316+
)
5317+
graph.sort()
5318+
self.assertEqual(_sort_op_names(graph), ["Relu", "CustomMultiBranch"])
5319+
5320+
def test_sort_subgraph_with_outer_scope_input(self):
5321+
"""Subgraph node consuming a value produced in the parent graph should not crash."""
5322+
x = ir.Value(name="X")
5323+
relu = ir.Node("", "Relu", inputs=[x], name="relu")
5324+
relu_out = relu.outputs[0]
5325+
relu_out.name = "relu_out"
5326+
5327+
sub_in = ir.Value(name="sub_in")
5328+
sub_add = ir.Node("", "Add", inputs=[sub_in, relu_out], name="sub_add")
5329+
sub_out = sub_add.outputs[0]
5330+
sub_out.name = "sub_out"
5331+
sub_sig = ir.Node("", "Sigmoid", inputs=[sub_out], name="sub_sig")
5332+
sub_final = sub_sig.outputs[0]
5333+
sub_final.name = "sub_final"
5334+
5335+
subgraph = ir.Graph(
5336+
inputs=[sub_in],
5337+
outputs=[sub_final],
5338+
nodes=[sub_sig, sub_add], # Unsorted
5339+
name="body",
5340+
opset_imports={"": 21},
5341+
)
5342+
5343+
body_attr = ir.Attr("body", ir.AttributeType.GRAPH, subgraph)
5344+
loop_node = ir.Node("", "Loop", [relu_out], [body_attr], name="loop")
5345+
loop_out = loop_node.outputs[0]
5346+
loop_out.name = "Y"
5347+
5348+
graph = ir.Graph(
5349+
inputs=[x],
5350+
outputs=[loop_out],
5351+
nodes=[loop_node, relu], # Unsorted
5352+
name="main",
5353+
opset_imports={"": 21},
5354+
)
5355+
5356+
graph.sort()
5357+
5358+
self.assertEqual(_sort_op_names(graph), ["Relu", "Loop"])
5359+
5360+
body = loop_node.attributes["body"].value
5361+
self.assertEqual(_sort_op_names(body), ["Add", "Sigmoid"])
5362+
5363+
def test_sort_subgraph_directly_with_outer_scope_reference(self):
5364+
"""Calling sort() on a subgraph directly when it references outer-scope values."""
5365+
x = ir.Value(name="X")
5366+
relu = ir.Node("", "Relu", inputs=[x], name="relu")
5367+
relu_out = relu.outputs[0]
5368+
relu_out.name = "relu_out"
5369+
5370+
sub_in = ir.Value(name="sub_in")
5371+
sub_add = ir.Node("", "Add", inputs=[sub_in, relu_out], name="sub_add")
5372+
sub_v = sub_add.outputs[0]
5373+
sub_v.name = "sub_v"
5374+
sub_sig = ir.Node("", "Sigmoid", inputs=[sub_v], name="sub_sig")
5375+
sub_out = sub_sig.outputs[0]
5376+
sub_out.name = "sub_out"
5377+
5378+
subgraph = ir.Graph(
5379+
inputs=[sub_in],
5380+
outputs=[sub_out],
5381+
nodes=[sub_sig, sub_add], # Unsorted
5382+
name="body",
5383+
opset_imports={"": 21},
5384+
)
5385+
5386+
# Sort the subgraph directly — relu is NOT in this graph's nodes.
5387+
# This is the exact scenario that caused the original KeyError bug.
5388+
subgraph.sort()
5389+
self.assertEqual(_sort_op_names(subgraph), ["Add", "Sigmoid"])
5390+
5391+
def test_sort_deeply_nested_outer_scope(self):
5392+
"""A deeply nested subgraph referencing a grandparent value."""
5393+
x = ir.Value(name="X")
5394+
relu = ir.Node("", "Relu", inputs=[x], name="relu")
5395+
relu_out = relu.outputs[0]
5396+
relu_out.name = "relu_out"
5397+
5398+
inner_in = ir.Value(name="inner_in")
5399+
inner_add = ir.Node("", "Add", inputs=[inner_in, relu_out], name="inner_add")
5400+
inner_out = inner_add.outputs[0]
5401+
inner_out.name = "inner_out"
5402+
inner_graph = ir.Graph(
5403+
inputs=[inner_in],
5404+
outputs=[inner_out],
5405+
nodes=[inner_add],
5406+
name="inner",
5407+
opset_imports={"": 21},
5408+
)
5409+
5410+
mid_in = ir.Value(name="mid_in")
5411+
inner_attr = ir.Attr("body", ir.AttributeType.GRAPH, inner_graph)
5412+
mid_node = ir.Node("", "Loop", [mid_in], [inner_attr], name="mid_loop")
5413+
mid_out = mid_node.outputs[0]
5414+
mid_out.name = "mid_out"
5415+
mid_graph = ir.Graph(
5416+
inputs=[mid_in],
5417+
outputs=[mid_out],
5418+
nodes=[mid_node],
5419+
name="middle",
5420+
opset_imports={"": 21},
5421+
)
5422+
5423+
outer_attr = ir.Attr("body", ir.AttributeType.GRAPH, mid_graph)
5424+
outer_loop = ir.Node("", "Loop", [relu_out], [outer_attr], name="outer_loop")
5425+
outer_out = outer_loop.outputs[0]
5426+
outer_out.name = "Y"
5427+
5428+
graph = ir.Graph(
5429+
inputs=[x],
5430+
outputs=[outer_out],
5431+
nodes=[outer_loop, relu], # Unsorted
5432+
name="main",
5433+
opset_imports={"": 21},
5434+
)
5435+
5436+
graph.sort()
5437+
self.assertEqual(_sort_op_names(graph), ["Relu", "Loop"])
5438+
5439+
def test_sort_raises_on_cycle(self):
5440+
"""A graph with a cycle should raise ValueError."""
5441+
v_a = ir.Value(name="v_a")
5442+
v_b = ir.Value(name="v_b")
5443+
5444+
node_a = ir.Node("", "Relu", inputs=[v_b], name="A", outputs=[v_a])
5445+
node_b = ir.Node("", "Sigmoid", inputs=[v_a], name="B", outputs=[v_b])
5446+
5447+
x = ir.Value(name="X")
5448+
graph = ir.Graph(
5449+
inputs=[x],
5450+
outputs=[v_a],
5451+
nodes=[node_a, node_b],
5452+
name="cycle",
5453+
opset_imports={"": 21},
5454+
)
5455+
with self.assertRaisesRegex(ValueError, "cycle"):
5456+
graph.sort()
5457+
5458+
50775459
if __name__ == "__main__":
50785460
unittest.main()

0 commit comments

Comments
 (0)