3434from onnx import NodeProto # noqa
3535from onnx import helper as oh
3636from qonnx .core .datatype import DataType
37-
38- # QONNX wrapper of ONNX model graphs
39- from qonnx .core .modelwrapper import ModelWrapper
4037from qonnx .custom_op .registry import getCustomOp
4138from qonnx .transformation .base import Transformation
4239from qonnx .transformation .infer_datatypes import InferDataTypes
@@ -106,19 +103,6 @@ def apply(self, model):
106103 return (model , graph_modified )
107104
108105
109- # Groups inputs by categories, i.e., groups dynamic inputs first, followed by
110- # initializers. Keeps order of inputs in each category.
111- def group_inputs_by_category (node : NodeProto , model : ModelWrapper ): # noqa
112- # First select all dynamic inputs, which are those without initializer
113- # tensor
114- dynamics = [i for i in node .input if model .get_initializer (i ) is None ]
115- # Select all input which are initializers, which, by exclusion, are all
116- # those not among the dynamic inputs
117- initializers = [i for i in node .input if i not in dynamics ]
118- # Return lists of dynamic anc initializer inputs
119- return dynamics , initializers
120-
121-
122106class AbsorbAddIntoMultiThreshold (Transformation ):
123107 """Absorb preceding Add ops into MultiThreshold by updating the threshold
124108 values. Only scalar/1D add vectors can be absorbed."""
@@ -132,17 +116,13 @@ def apply(self, model):
132116 if n .op_type == "Add" and not model .is_fork_node (n ) and not model .is_join_node (n ):
133117 consumer = model .find_consumer (n .output [0 ])
134118 if consumer is not None and consumer .op_type == "MultiThreshold" :
135- # As Add is not a join node, there must be one initializer
136- # and one dynamic input. We do not know their order, but
137- # can group them accordingly to extract the tensor names
138- (start ,), (add_weight ,) = group_inputs_by_category (n , model )
139- threshold = consumer .input [1 ]
140- A = model .get_initializer (add_weight )
141- T = model .get_initializer (threshold )
142- # Test for the thresholds actually being initializers
143- # Note: No need to validate the add_weights anymore, this
144- # is already handled by the grouping and is_join_node test.
119+ add_weight_name = n .input [1 ]
120+ threshold_name = consumer .input [1 ]
121+ A = model .get_initializer (add_weight_name )
122+ T = model .get_initializer (threshold_name )
123+ assert A is not None , "Initializer for add weights is not set."
145124 assert T is not None , "Initializer for thresholds is not set."
125+ start_name = n .input [0 ]
146126 # we can only absorb 0d or 1d adds
147127 is_scalar = A .ndim == 0 or all (x == 1 for x in A .shape )
148128 actual_ndims = len (tuple (filter (lambda x : x > 1 , A .shape )))
@@ -151,9 +131,9 @@ def apply(self, model):
151131 Tnew = T - A .reshape (- 1 , 1 )
152132 # Tnew = T - A.reshape(-1, T.shape[1])
153133 # compute new thresholds and set initializer
154- model .set_initializer (threshold , Tnew )
134+ model .set_initializer (threshold_name , Tnew )
155135 # wire add input directly to MultiThreshold
156- consumer .input [0 ] = start
136+ consumer .input [0 ] = start_name
157137 # remove the add node
158138 graph .node .remove (n )
159139 graph_modified = True
0 commit comments