Skip to content

Commit cae4dc7

Browse files
author
Yaman Umuroglu
authored
Merge pull request #187 from fastmachinelearning/feature/transform-subgraph-traveral
Add proper subgraph-traversal for qonnx model_wrapper transform function
2 parents ad48037 + 82ee6b0 commit cae4dc7

2 files changed

Lines changed: 283 additions & 15 deletions

File tree

src/qonnx/core/modelwrapper.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# Copyright (c) 2020 Xilinx, Inc.
1+
# Copyright (c) 2022 - 2025 Advanced Micro Devices, Inc.
2+
# Copyright (c) 2020 - 2022 Xilinx, Inc.
23
# All rights reserved.
34
#
45
# Redistribution and use in source and binary forms, with or without
@@ -128,23 +129,58 @@ def analysis(self, analysis_fxn):
128129
"""Runs given anaylsis_fxn on this model and return resulting dict."""
129130
return analysis_fxn(self)
130131

131-
def transform(self, transformation, make_deepcopy=True, cleanup=True):
132+
def transform_subgraphs(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True):
133+
"""Applies given Transformation to all subgraphs of this ModelWrapper instance.
134+
135+
- make_deepcopy : operates on a new (deep)copy of model.
136+
- cleanup : execute cleanup transformations before returning
137+
- apply_to_subgraphs : if True, transformation is applied to all subgraphs of the model
138+
- use_preorder_traversal : if True, uses preorder traversal for subgraph transformation,
139+
otherwise postorder traversal is used.
140+
"""
141+
for node in self.model.graph.node:
142+
transformed_subgraph_attrs = []
143+
for idx, attr in enumerate(node.attribute):
144+
if attr.type == onnx.AttributeProto.GRAPH:
145+
# this is a subgraph, add it to the list
146+
subgraph = self.make_subgraph_modelwrapper(attr.g)
147+
# apply the transformation to the subgraph
148+
subgraph = subgraph.transform(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal)
149+
# update the new subgraph in the attrubute
150+
transformed_subgraph_attrs.append((idx, onnx.helper.make_attribute(attr.name, subgraph.model.graph)))
151+
# replace the attributes in the node with the transformed subgraph attributes
152+
for idx, new_attr in transformed_subgraph_attrs:
153+
# remove the old attribute
154+
node.attribute.pop(idx)
155+
# add the new attribute
156+
node.attribute.insert(idx, new_attr)
157+
158+
def transform(self, transformation, make_deepcopy=True, cleanup=True, apply_to_subgraphs=False, use_preorder_traversal=True):
132159
"""Applies given Transformation repeatedly until no more changes can be made
133160
and returns a transformed ModelWrapper instance.
134161
135162
- make_deepcopy : operates on a new (deep)copy of model.
136163
- cleanup : execute cleanup transformations before returning
164+
- apply_to_subgraphs : if True, transformation is applied to all subgraphs of the model
137165
"""
138166
transformed_model = self
139167
if make_deepcopy:
140168
transformed_model = copy.deepcopy(self)
141169
if self.fix_float64:
142170
(transformed_model, model_was_changed) = DoubleToSingleFloat().apply(transformed_model)
171+
172+
if apply_to_subgraphs and use_preorder_traversal == False:
173+
transformed_model.transform_subgraphs(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal)
174+
143175
model_was_changed = True
144176
while model_was_changed:
145177
(transformed_model, model_was_changed) = transformation.apply(transformed_model)
146178
if cleanup:
147179
transformed_model.cleanup()
180+
181+
if apply_to_subgraphs and use_preorder_traversal:
182+
transformed_model.transform_subgraphs(transformation, make_deepcopy, cleanup, apply_to_subgraphs, use_preorder_traversal)
183+
148184
return transformed_model
149185

150186
def cleanup(self):
@@ -160,19 +196,8 @@ def cleanup(self):
160196
transformed_model = transformed_model.transform(trn, cleanup=False, make_deepcopy=False)
161197
return transformed_model
162198

163-
def check_compatibility(self):
164-
"""Checks this model for QONNX compatibility:
165-
166-
* no embedded subgraphs
167-
168-
* all tensor shapes are specified, including activations
169-
170-
* all constants are initializers
171-
"""
172-
# TODO check for no embedded subgraphs
173-
# TODO check that all shapes are inferred
174-
# TODO check that all constants are initializers
175-
return True
199+
def make_subgraph_modelwrapper(self, subgraph):
200+
return ModelWrapper(util.qonnx_make_model(subgraph, opset_imports=self._model_proto.opset_import))
176201

177202
def get_tensor_datatype(self, tensor_name):
178203
"""Returns the QONNX DataType of tensor with given name."""
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import pytest
2+
from collections import Counter
3+
4+
from qonnx.core.modelwrapper import ModelWrapper
5+
from qonnx.transformation.base import Transformation
6+
7+
from qonnx.util.basic import qonnx_make_model, get_by_name
8+
import onnx
9+
from onnx import helper
10+
11+
# Helper to recursively build a graph with subgraphs attached to nodes
12+
def make_graph(tree):
13+
"""
14+
Recursively build a ModelWrapper tree from a nested tuple/list structure.
15+
Each graph will have one node per subgraph, with the subgraph attached as a node attribute.
16+
Example input:
17+
("top", [("sub1", []), ("sub2", [("sub2_1", [])])])
18+
Returns the top-level ModelWrapper.
19+
"""
20+
name, subtrees = tree
21+
# Create subgraphs recursively
22+
subgraph_nodes = []
23+
inputs = []
24+
outputs = []
25+
for subtree in subtrees:
26+
subgraph = make_graph(subtree)
27+
sg_name_in = f"{subgraph.name}_in"
28+
sg_name_out = f"{subgraph.name}_out"
29+
inputs.append(onnx.helper.make_tensor_value_info(sg_name_in, onnx.TensorProto.FLOAT, [4, 4]))
30+
outputs.append(onnx.helper.make_tensor_value_info(sg_name_out, onnx.TensorProto.FLOAT, [4, 4]))
31+
# Attach subgraph as attribute to node
32+
node = helper.make_node(
33+
op_type="SubgraphNode", # dummy op_type
34+
inputs=[sg_name_in],
35+
outputs=[sg_name_out],
36+
name=f"{subgraph.name}_node",
37+
)
38+
# ONNX expects subgraphs as AttributeProto, so we set it below
39+
attr = onnx.helper.make_attribute("body", subgraph)
40+
node.attribute.append(attr)
41+
subgraph_nodes.append(node)
42+
# Create the graph for this level
43+
graph = helper.make_graph(
44+
nodes=subgraph_nodes,
45+
name=name,
46+
inputs=inputs,
47+
outputs=outputs,
48+
)
49+
50+
return graph
51+
52+
def make_subgraph_model(tree):
53+
"""
54+
Build a ModelWrapper with a graph structure based on the provided tree.
55+
The tree is a nested tuple/list structure where each node can have subgraphs.
56+
"""
57+
return ModelWrapper(qonnx_make_model(make_graph(tree), opset_imports=[helper.make_opsetid("", 10)]))
58+
59+
60+
class DummyTransform(Transformation):
61+
def __init__(self):
62+
self.visited = list()
63+
64+
def apply(self, model_wrapper):
65+
# get the name of the graph being transformed
66+
graph_name = model_wrapper.model.graph.name
67+
# set a metadata property to test whether metadata is preserved
68+
model_wrapper.set_metadata_prop(graph_name, "visited")
69+
model_wrapper.set_metadata_prop("opset_id", str(model_wrapper.model.opset_import[0].version))
70+
# add a dummy node to the graph to simulate a transformation
71+
# to see if the subgraph transformation is presered
72+
73+
dummy_name_in = f"{graph_name}_dummy_in"
74+
dummy_name_out = f"{graph_name}_dummy_out"
75+
model_wrapper.model.graph.input.append(helper.make_tensor_value_info(dummy_name_in, onnx.TensorProto.FLOAT, [4, 4]))
76+
model_wrapper.model.graph.output.append(helper.make_tensor_value_info(dummy_name_out, onnx.TensorProto.FLOAT, [4, 4]))
77+
model_wrapper.model.graph.node.append(
78+
helper.make_node(
79+
"DummyNode", # dummy op_type
80+
inputs=[dummy_name_in],
81+
outputs=[dummy_name_out],
82+
name=f"{graph_name}_dummy_node",
83+
)
84+
)
85+
86+
# collect the name of the graph being transformed to check how many times each graph was visited
87+
self.visited.append(graph_name)
88+
#import pdb; pdb.set_trace()
89+
return model_wrapper, False
90+
91+
class NestedTransform(Transformation):
92+
def __init__(self):
93+
self.dummy_transform = DummyTransform()
94+
def apply(self, model_wrapper):
95+
return model_wrapper.transform(self.dummy_transform), False
96+
97+
def get_subgraph_names(tree):
98+
"""
99+
Recursively collect the names of all subgraphs in the tree structure.
100+
"""
101+
names = set()
102+
103+
def traverse(tree):
104+
name = tree[0]
105+
subgraphs = tree[1]
106+
names.add(name)
107+
for subgraph in subgraphs:
108+
traverse(subgraph)
109+
110+
traverse(tree)
111+
return names
112+
113+
114+
def check_all_visted_once(tree, transform):
115+
"""
116+
Check that all subgraphs in the tree structure were visited exactly once.
117+
"""
118+
visited = transform.visited
119+
expected = get_subgraph_names(tree)
120+
assert Counter(visited) == Counter(expected), f"Visited: {visited}, Expected: {expected}"
121+
122+
def check_visit_order(tree, transform, order):
123+
"""
124+
Check that the order of visited subgraphs matches the expected preorder or postorder traversal.
125+
"""
126+
visited = transform.visited
127+
expected = order(tree)
128+
assert visited == expected, f"Visited: {visited}, Expected: {expected}"
129+
130+
def check_all_subgraphs_transformed(graph):
131+
"""
132+
Check that all subgraphs in the tree structure have been transformed.
133+
"""
134+
135+
# look for the optype "DummyNode" in the model graph
136+
dummynode_found = False
137+
for node in graph.node:
138+
if node.op_type == "DummyNode":
139+
dummynode_found = True
140+
break
141+
if not dummynode_found:
142+
raise AssertionError(f"DummyNode not found in the transformed model graph {graph.name}")
143+
144+
# check that metadata is set for all subgraphs
145+
def get_metadata_props(graph, key):
146+
metadata_prop = get_by_name(graph.metadata_props, key, "key")
147+
if metadata_prop is None:
148+
return None
149+
else:
150+
return metadata_prop.value
151+
152+
assert(get_metadata_props(graph, graph.name) == "visited"), f"Metadata for {graph.name} not set correctly"
153+
assert(get_metadata_props(graph, "opset_id") == "10"), "Metadata for opset_id not set correctly"
154+
# recursively check all subgraphs
155+
for node in graph.node:
156+
for attr in node.attribute:
157+
if attr.type == onnx.AttributeProto.GRAPH:
158+
check_all_subgraphs_transformed(attr.g)
159+
160+
@pytest.mark.parametrize("cleanup", [False, True])
161+
@pytest.mark.parametrize("make_deepcopy", [False, True])
162+
@pytest.mark.parametrize("tree, apply_to_subgraphs",
163+
[(("top", []), True),
164+
(("top", []), False),
165+
(("top", [("sub1", [])]), False)])
166+
def test_no_traversal(tree, cleanup, make_deepcopy, apply_to_subgraphs):
167+
# Check that the top-level model is transformed exactly once when there are no subgraphs.
168+
# Check that the top-level model is transformed exactly once when there are subgraphs, but apply_to_subgraphs is False.
169+
# This should always be done correctly regardless of cleanup and make_deepcopy.
170+
171+
model = make_subgraph_model(tree)
172+
transform = DummyTransform()
173+
t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs)
174+
175+
assert transform.visited == ["top"]
176+
assert t_model.get_metadata_prop("top") == "visited"
177+
178+
def build_preorder_traversal(tree):
179+
"""
180+
Build a preorder traversal of the tree structure.
181+
"""
182+
traversal = []
183+
184+
def traverse(node):
185+
name, subtrees = node
186+
traversal.append(name)
187+
for subtree in subtrees:
188+
traverse(subtree)
189+
190+
traverse(tree)
191+
return traversal
192+
193+
def build_postorder_traversal(tree):
194+
"""
195+
Build a postorder traversal of the tree structure.
196+
"""
197+
traversal = []
198+
199+
def traverse(node):
200+
name, subtrees = node
201+
for subtree in subtrees:
202+
traverse(subtree)
203+
traversal.append(name)
204+
205+
traverse(tree)
206+
return traversal
207+
208+
@pytest.mark.parametrize("cleanup", [False, True])
209+
@pytest.mark.parametrize("make_deepcopy", [False, True])
210+
@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]),
211+
("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])])
212+
@pytest.mark.parametrize("use_preorder_traversal", [True, False])
213+
def test_traversal(tree, cleanup, make_deepcopy, use_preorder_traversal):
214+
# Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True.
215+
# This should always be done correctly regardless of cleanup and make_deepcopy.
216+
print(f"Testing tree: {tree}, cleanup: {cleanup}, make_deepcopy: {make_deepcopy}")
217+
model = make_subgraph_model(tree)
218+
transform = DummyTransform()
219+
t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs=True, use_preorder_traversal=use_preorder_traversal)
220+
221+
check_all_visted_once(tree, transform)
222+
check_all_subgraphs_transformed(t_model.model.graph)
223+
224+
if use_preorder_traversal:
225+
traversal_order = build_preorder_traversal
226+
else:
227+
traversal_order = build_postorder_traversal
228+
check_visit_order(tree, transform, traversal_order)
229+
230+
231+
@pytest.mark.parametrize("cleanup", [False, True])
232+
@pytest.mark.parametrize("make_deepcopy", [False, True])
233+
@pytest.mark.parametrize("tree", [("top", [("sub1", []), ("sub2", [])]),
234+
("top", [("sub1", [("sub1_1", []), ("sub1_2",[])]), ("sub2", [("sub2_1", [])])])])
235+
def test_traversal_nested(tree, cleanup, make_deepcopy):
236+
# Check that the top-level model and all subgraphs are transformed when apply_to_subgraphs is True.
237+
# This should always be done correctly regardless of cleanup and make_deepcopy.
238+
model = make_subgraph_model(tree)
239+
transform = NestedTransform()
240+
t_model = model.transform(transform, cleanup, make_deepcopy, apply_to_subgraphs=True)
241+
242+
check_all_visted_once(tree, transform.dummy_transform)
243+
check_all_subgraphs_transformed(t_model.model.graph)

0 commit comments

Comments
 (0)