Skip to content

Commit cc2ec17

Browse files
authored
Merge pull request #169 from iksnagreb/feature/composed-transformation
[Transform] Introduce ComposedTransformation
2 parents f5c9819 + 4fc5070 commit cc2ec17

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

.isort.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ sections=FUTURE,STDLIB,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
99
default_section=THIRDPARTY
1010
multi_line_output=3
1111
profile=black
12+
ignore_comments=true
13+
ignore_whitespace=true
14+
honor_noqa=true
15+
use_parentheses=true
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import copy
2+
3+
# QONNX wrapper of ONNX model graphs
4+
from qonnx.core.modelwrapper import ModelWrapper
5+
6+
# QONNX graph transformations for annotating the graph with datatype and shape
7+
# information
8+
from qonnx.transformation.infer_datatypes import InferDataTypes
9+
from qonnx.transformation.infer_shapes import InferShapes
10+
11+
# Cleanup transformations removing identities like multiplication by one or
12+
# addition of zero
13+
from qonnx.transformation.remove import RemoveIdentityOps
14+
15+
# Base class for all QONNX graph transformations and some basic cleanup
16+
# transformations
17+
# fmt: off
18+
from qonnx.transformation.general import ( # isort: skip
19+
GiveReadableTensorNames, GiveUniqueNodeNames, Transformation
20+
)
21+
22+
23+
# fmt: on
24+
25+
26+
# Composes graph transformations such that each individual transformation as
27+
# well as the whole sequence is applied exhaustively
28+
class ComposedTransformation(Transformation):
29+
# Initializes the transformation given a list of transformations
30+
def __init__(self, transformations: list[Transformation]):
31+
super().__init__()
32+
# Register the list of transformations to be applied in apply()
33+
self.transformations = transformations
34+
35+
def apply(self, model: ModelWrapper): # noqa
36+
# Keep track of whether the graph has been modified
37+
graph_modified = False
38+
# Iterate all transformations to be applied
39+
for transformation in self.transformations:
40+
# Start each transformation on a deep copy of the model to mimic the
41+
# behavior of ModelWrapper.transform()
42+
model = copy.deepcopy(model)
43+
# Exhaustively apply the transformation until it no longer modifies
44+
# the graph
45+
while True:
46+
# Apply the transformation once, reporting back whether any node
47+
# or pattern has been modified
48+
model, _graph_modified = transformation.apply(model)
49+
# Keep track whether the graph has been modified at least once
50+
graph_modified = graph_modified or _graph_modified
51+
# Break the loop if this transformation did not change anything
52+
if not _graph_modified:
53+
break
54+
# Apply the default cleanup transformations of the ModelWrapper
55+
model.cleanup()
56+
# Apply some further cleanup transformations to the model graph
57+
# removing some clutter and keeping all names readable and ordered
58+
# at any time
59+
model = model.transform(RemoveIdentityOps())
60+
model = model.transform(GiveUniqueNodeNames())
61+
model = model.transform(GiveReadableTensorNames())
62+
model = model.transform(InferShapes())
63+
model = model.transform(InferDataTypes())
64+
# Return the transformed model and indicate whether the graph actually
65+
# has been transformed by at least one transformation so the whole
66+
# sequence of transformations will be reapplied
67+
return model, graph_modified

0 commit comments

Comments
 (0)