@@ -5074,5 +5074,387 @@ def test_clone_graph_view_with_intermediate_values(self):
50745074 self .assertEqual (full_graph [2 ].name , "mul" )
50755075
50765076
5077+ def _sort_op_names (graph : ir .Graph ) -> list [str ]:
5078+ """Return the list of op_type names in graph order."""
5079+ return [node .op_type for node in graph ]
5080+
5081+
5082+ def _build_linear_graph (op_types : list [str ], graph_name : str = "test" ) -> ir .Graph :
5083+ """Build a linear chain: X -> op1 -> op2 -> ... -> Y."""
5084+ x = ir .Value (name = "X" )
5085+ prev = x
5086+ nodes = []
5087+ for i , op_type in enumerate (op_types ):
5088+ node = ir .Node ("" , op_type , inputs = [prev ], name = f"n{ i } " )
5089+ prev = node .outputs [0 ]
5090+ prev .name = f"v{ i } "
5091+ nodes .append (node )
5092+
5093+ graph = ir .Graph (
5094+ inputs = [x ],
5095+ outputs = [prev ],
5096+ nodes = nodes ,
5097+ name = graph_name ,
5098+ opset_imports = {"" : 21 },
5099+ )
5100+ return graph
5101+
5102+
5103+ class GraphSortTest (unittest .TestCase ):
5104+ """Tests for Graph.sort() topological sorting."""
5105+
5106+ def test_sort_empty_graph (self ):
5107+ x = ir .Value (name = "X" )
5108+ graph = ir .Graph (
5109+ inputs = [x ], outputs = [x ], nodes = [], name = "empty" , opset_imports = {"" : 21 }
5110+ )
5111+ graph .sort ()
5112+ self .assertEqual (list (graph ), [])
5113+
5114+ def test_sort_single_node (self ):
5115+ graph = _build_linear_graph (["Relu" ])
5116+ graph .sort ()
5117+ self .assertEqual (_sort_op_names (graph ), ["Relu" ])
5118+
5119+ def test_sort_already_sorted_linear (self ):
5120+ graph = _build_linear_graph (["Relu" , "Sigmoid" , "Tanh" ])
5121+ graph .sort ()
5122+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "Sigmoid" , "Tanh" ])
5123+
5124+ def test_sort_reversed_linear (self ):
5125+ """Nodes in reverse order should be sorted to topological order."""
5126+ x = ir .Value (name = "X" )
5127+ n0 = ir .Node ("" , "Relu" , inputs = [x ], name = "n0" )
5128+ v0 = n0 .outputs [0 ]
5129+ v0 .name = "v0"
5130+ n1 = ir .Node ("" , "Sigmoid" , inputs = [v0 ], name = "n1" )
5131+ v1 = n1 .outputs [0 ]
5132+ v1 .name = "v1"
5133+ n2 = ir .Node ("" , "Tanh" , inputs = [v1 ], name = "n2" )
5134+ v2 = n2 .outputs [0 ]
5135+ v2 .name = "v2"
5136+
5137+ # Insert in reverse order
5138+ graph = ir .Graph (
5139+ inputs = [x ],
5140+ outputs = [v2 ],
5141+ nodes = [n2 , n1 , n0 ],
5142+ name = "test" ,
5143+ opset_imports = {"" : 21 },
5144+ )
5145+ graph .sort ()
5146+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "Sigmoid" , "Tanh" ])
5147+
5148+ def test_sort_diamond_graph (self ):
5149+ """Diamond pattern: X -> A -> C, X -> B -> C."""
5150+ x = ir .Value (name = "X" )
5151+ a = ir .Node ("" , "Relu" , inputs = [x ], name = "A" )
5152+ va = a .outputs [0 ]
5153+ va .name = "va"
5154+ b = ir .Node ("" , "Sigmoid" , inputs = [x ], name = "B" )
5155+ vb = b .outputs [0 ]
5156+ vb .name = "vb"
5157+ c = ir .Node ("" , "Add" , inputs = [va , vb ], name = "C" )
5158+ vc = c .outputs [0 ]
5159+ vc .name = "vc"
5160+
5161+ # Insert out of order: C before A and B
5162+ graph = ir .Graph (
5163+ inputs = [x ],
5164+ outputs = [vc ],
5165+ nodes = [c , b , a ],
5166+ name = "test" ,
5167+ opset_imports = {"" : 21 },
5168+ )
5169+ graph .sort ()
5170+ ops = _sort_op_names (graph )
5171+ # C must come after both A and B
5172+ self .assertLess (ops .index ("Relu" ), ops .index ("Add" ))
5173+ self .assertLess (ops .index ("Sigmoid" ), ops .index ("Add" ))
5174+
5175+ def test_sort_with_none_inputs (self ):
5176+ """Nodes with None inputs should not cause errors."""
5177+ x = ir .Value (name = "X" )
5178+ node = ir .Node ("" , "Relu" , inputs = [None , x ], name = "n0" )
5179+ out = node .outputs [0 ]
5180+ out .name = "Y"
5181+
5182+ graph = ir .Graph (
5183+ inputs = [x ],
5184+ outputs = [out ],
5185+ nodes = [node ],
5186+ name = "test" ,
5187+ opset_imports = {"" : 21 },
5188+ )
5189+ graph .sort ()
5190+ self .assertEqual (_sort_op_names (graph ), ["Relu" ])
5191+
5192+ def test_sort_preserves_order_when_possible (self ):
5193+ """Sort is stable: independent nodes preserve original order."""
5194+ x = ir .Value (name = "X" )
5195+ a = ir .Node ("" , "Relu" , inputs = [x ], name = "A" )
5196+ va = a .outputs [0 ]
5197+ va .name = "va"
5198+ b = ir .Node ("" , "Sigmoid" , inputs = [x ], name = "B" )
5199+ vb = b .outputs [0 ]
5200+ vb .name = "vb"
5201+ c = ir .Node ("" , "Tanh" , inputs = [x ], name = "C" )
5202+ vc = c .outputs [0 ]
5203+ vc .name = "vc"
5204+ merge = ir .Node ("" , "Sum" , inputs = [va , vb , vc ], name = "Merge" , num_outputs = 1 )
5205+ out = merge .outputs [0 ]
5206+ out .name = "Y"
5207+
5208+ graph = ir .Graph (
5209+ inputs = [x ],
5210+ outputs = [out ],
5211+ nodes = [a , b , c , merge ],
5212+ name = "test" ,
5213+ opset_imports = {"" : 21 },
5214+ )
5215+ graph .sort ()
5216+ # Independent nodes A, B, C should keep their original order
5217+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "Sigmoid" , "Tanh" , "Sum" ])
5218+
5219+ def _make_if_graph (self , unsorted_subgraph : bool = False ) -> ir .Graph :
5220+ """Create a graph with an If node containing a subgraph."""
5221+ x = ir .Value (name = "X" )
5222+ cond_node = ir .Node ("" , "Relu" , inputs = [x ], name = "cond" )
5223+ cond_val = cond_node .outputs [0 ]
5224+ cond_val .name = "cond_val"
5225+
5226+ sub_in = ir .Value (name = "sub_in" )
5227+ sub_sig = ir .Node ("" , "Sigmoid" , inputs = [sub_in ], name = "sub_sig" )
5228+ sub_v = sub_sig .outputs [0 ]
5229+ sub_v .name = "sub_v"
5230+ sub_tanh = ir .Node ("" , "Tanh" , inputs = [sub_v ], name = "sub_tanh" )
5231+ sub_out = sub_tanh .outputs [0 ]
5232+ sub_out .name = "sub_out"
5233+
5234+ sub_nodes = [sub_tanh , sub_sig ] if unsorted_subgraph else [sub_sig , sub_tanh ]
5235+ subgraph = ir .Graph (
5236+ inputs = [sub_in ],
5237+ outputs = [sub_out ],
5238+ nodes = sub_nodes ,
5239+ name = "then_branch" ,
5240+ opset_imports = {"" : 21 },
5241+ )
5242+
5243+ then_attr = ir .Attr ("then_branch" , ir .AttributeType .GRAPH , subgraph )
5244+ if_node = ir .Node ("" , "If" , [cond_val ], [then_attr ], name = "if_node" )
5245+ if_out = if_node .outputs [0 ]
5246+ if_out .name = "Y"
5247+
5248+ graph = ir .Graph (
5249+ inputs = [x ],
5250+ outputs = [if_out ],
5251+ nodes = [if_node , cond_node ], # Unsorted: if before cond
5252+ name = "main" ,
5253+ opset_imports = {"" : 21 },
5254+ )
5255+ return graph
5256+
5257+ def test_sort_with_graph_attribute (self ):
5258+ """Sort handles subgraphs in GRAPH attributes."""
5259+ graph = self ._make_if_graph (unsorted_subgraph = True )
5260+ graph .sort ()
5261+
5262+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "If" ])
5263+
5264+ if_node = list (graph )[1 ]
5265+ subgraph = if_node .attributes ["then_branch" ].value
5266+ self .assertEqual (_sort_op_names (subgraph ), ["Sigmoid" , "Tanh" ])
5267+
5268+ def test_sort_with_graphs_attribute (self ):
5269+ """Sort handles subgraphs in GRAPHS attributes (multiple graphs)."""
5270+ x = ir .Value (name = "X" )
5271+ relu = ir .Node ("" , "Relu" , inputs = [x ], name = "relu" )
5272+ relu_out = relu .outputs [0 ]
5273+ relu_out .name = "relu_out"
5274+
5275+ sub_in1 = ir .Value (name = "sub_in1" )
5276+ sub_node1 = ir .Node ("" , "Sigmoid" , inputs = [sub_in1 ], name = "sub1" )
5277+ sub_out1 = sub_node1 .outputs [0 ]
5278+ sub_out1 .name = "sub_out1"
5279+ sg1 = ir .Graph (
5280+ inputs = [sub_in1 ],
5281+ outputs = [sub_out1 ],
5282+ nodes = [sub_node1 ],
5283+ name = "sg1" ,
5284+ opset_imports = {"" : 21 },
5285+ )
5286+
5287+ sub_in2 = ir .Value (name = "sub_in2" )
5288+ sub_node2 = ir .Node ("" , "Tanh" , inputs = [sub_in2 ], name = "sub2" )
5289+ sub_out2 = sub_node2 .outputs [0 ]
5290+ sub_out2 .name = "sub_out2"
5291+ sg2 = ir .Graph (
5292+ inputs = [sub_in2 ],
5293+ outputs = [sub_out2 ],
5294+ nodes = [sub_node2 ],
5295+ name = "sg2" ,
5296+ opset_imports = {"" : 21 },
5297+ )
5298+
5299+ graphs_attr = ir .Attr ("branches" , ir .AttributeType .GRAPHS , [sg1 , sg2 ])
5300+ multi_node = ir .Node (
5301+ "" ,
5302+ "CustomMultiBranch" ,
5303+ [relu_out ],
5304+ [graphs_attr ],
5305+ name = "multi" ,
5306+ )
5307+ multi_out = multi_node .outputs [0 ]
5308+ multi_out .name = "Y"
5309+
5310+ graph = ir .Graph (
5311+ inputs = [x ],
5312+ outputs = [multi_out ],
5313+ nodes = [multi_node , relu ], # Unsorted
5314+ name = "main" ,
5315+ opset_imports = {"" : 21 },
5316+ )
5317+ graph .sort ()
5318+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "CustomMultiBranch" ])
5319+
5320+ def test_sort_subgraph_with_outer_scope_input (self ):
5321+ """Subgraph node consuming a value produced in the parent graph should not crash."""
5322+ x = ir .Value (name = "X" )
5323+ relu = ir .Node ("" , "Relu" , inputs = [x ], name = "relu" )
5324+ relu_out = relu .outputs [0 ]
5325+ relu_out .name = "relu_out"
5326+
5327+ sub_in = ir .Value (name = "sub_in" )
5328+ sub_add = ir .Node ("" , "Add" , inputs = [sub_in , relu_out ], name = "sub_add" )
5329+ sub_out = sub_add .outputs [0 ]
5330+ sub_out .name = "sub_out"
5331+ sub_sig = ir .Node ("" , "Sigmoid" , inputs = [sub_out ], name = "sub_sig" )
5332+ sub_final = sub_sig .outputs [0 ]
5333+ sub_final .name = "sub_final"
5334+
5335+ subgraph = ir .Graph (
5336+ inputs = [sub_in ],
5337+ outputs = [sub_final ],
5338+ nodes = [sub_sig , sub_add ], # Unsorted
5339+ name = "body" ,
5340+ opset_imports = {"" : 21 },
5341+ )
5342+
5343+ body_attr = ir .Attr ("body" , ir .AttributeType .GRAPH , subgraph )
5344+ loop_node = ir .Node ("" , "Loop" , [relu_out ], [body_attr ], name = "loop" )
5345+ loop_out = loop_node .outputs [0 ]
5346+ loop_out .name = "Y"
5347+
5348+ graph = ir .Graph (
5349+ inputs = [x ],
5350+ outputs = [loop_out ],
5351+ nodes = [loop_node , relu ], # Unsorted
5352+ name = "main" ,
5353+ opset_imports = {"" : 21 },
5354+ )
5355+
5356+ graph .sort ()
5357+
5358+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "Loop" ])
5359+
5360+ body = loop_node .attributes ["body" ].value
5361+ self .assertEqual (_sort_op_names (body ), ["Add" , "Sigmoid" ])
5362+
5363+ def test_sort_subgraph_directly_with_outer_scope_reference (self ):
5364+ """Calling sort() on a subgraph directly when it references outer-scope values."""
5365+ x = ir .Value (name = "X" )
5366+ relu = ir .Node ("" , "Relu" , inputs = [x ], name = "relu" )
5367+ relu_out = relu .outputs [0 ]
5368+ relu_out .name = "relu_out"
5369+
5370+ sub_in = ir .Value (name = "sub_in" )
5371+ sub_add = ir .Node ("" , "Add" , inputs = [sub_in , relu_out ], name = "sub_add" )
5372+ sub_v = sub_add .outputs [0 ]
5373+ sub_v .name = "sub_v"
5374+ sub_sig = ir .Node ("" , "Sigmoid" , inputs = [sub_v ], name = "sub_sig" )
5375+ sub_out = sub_sig .outputs [0 ]
5376+ sub_out .name = "sub_out"
5377+
5378+ subgraph = ir .Graph (
5379+ inputs = [sub_in ],
5380+ outputs = [sub_out ],
5381+ nodes = [sub_sig , sub_add ], # Unsorted
5382+ name = "body" ,
5383+ opset_imports = {"" : 21 },
5384+ )
5385+
5386+ # Sort the subgraph directly — relu is NOT in this graph's nodes.
5387+ # This is the exact scenario that caused the original KeyError bug.
5388+ subgraph .sort ()
5389+ self .assertEqual (_sort_op_names (subgraph ), ["Add" , "Sigmoid" ])
5390+
5391+ def test_sort_deeply_nested_outer_scope (self ):
5392+ """A deeply nested subgraph referencing a grandparent value."""
5393+ x = ir .Value (name = "X" )
5394+ relu = ir .Node ("" , "Relu" , inputs = [x ], name = "relu" )
5395+ relu_out = relu .outputs [0 ]
5396+ relu_out .name = "relu_out"
5397+
5398+ inner_in = ir .Value (name = "inner_in" )
5399+ inner_add = ir .Node ("" , "Add" , inputs = [inner_in , relu_out ], name = "inner_add" )
5400+ inner_out = inner_add .outputs [0 ]
5401+ inner_out .name = "inner_out"
5402+ inner_graph = ir .Graph (
5403+ inputs = [inner_in ],
5404+ outputs = [inner_out ],
5405+ nodes = [inner_add ],
5406+ name = "inner" ,
5407+ opset_imports = {"" : 21 },
5408+ )
5409+
5410+ mid_in = ir .Value (name = "mid_in" )
5411+ inner_attr = ir .Attr ("body" , ir .AttributeType .GRAPH , inner_graph )
5412+ mid_node = ir .Node ("" , "Loop" , [mid_in ], [inner_attr ], name = "mid_loop" )
5413+ mid_out = mid_node .outputs [0 ]
5414+ mid_out .name = "mid_out"
5415+ mid_graph = ir .Graph (
5416+ inputs = [mid_in ],
5417+ outputs = [mid_out ],
5418+ nodes = [mid_node ],
5419+ name = "middle" ,
5420+ opset_imports = {"" : 21 },
5421+ )
5422+
5423+ outer_attr = ir .Attr ("body" , ir .AttributeType .GRAPH , mid_graph )
5424+ outer_loop = ir .Node ("" , "Loop" , [relu_out ], [outer_attr ], name = "outer_loop" )
5425+ outer_out = outer_loop .outputs [0 ]
5426+ outer_out .name = "Y"
5427+
5428+ graph = ir .Graph (
5429+ inputs = [x ],
5430+ outputs = [outer_out ],
5431+ nodes = [outer_loop , relu ], # Unsorted
5432+ name = "main" ,
5433+ opset_imports = {"" : 21 },
5434+ )
5435+
5436+ graph .sort ()
5437+ self .assertEqual (_sort_op_names (graph ), ["Relu" , "Loop" ])
5438+
5439+ def test_sort_raises_on_cycle (self ):
5440+ """A graph with a cycle should raise ValueError."""
5441+ v_a = ir .Value (name = "v_a" )
5442+ v_b = ir .Value (name = "v_b" )
5443+
5444+ node_a = ir .Node ("" , "Relu" , inputs = [v_b ], name = "A" , outputs = [v_a ])
5445+ node_b = ir .Node ("" , "Sigmoid" , inputs = [v_a ], name = "B" , outputs = [v_b ])
5446+
5447+ x = ir .Value (name = "X" )
5448+ graph = ir .Graph (
5449+ inputs = [x ],
5450+ outputs = [v_a ],
5451+ nodes = [node_a , node_b ],
5452+ name = "cycle" ,
5453+ opset_imports = {"" : 21 },
5454+ )
5455+ with self .assertRaisesRegex (ValueError , "cycle" ):
5456+ graph .sort ()
5457+
5458+
50775459if __name__ == "__main__" :
50785460 unittest .main ()
0 commit comments