Skip to content

Commit 42ca6ca

Browse files
committed
[Transform] Remove broadcasting and grouping inputs in AbsorbAddIntoMultiThreshold
1 parent 4ff96ba commit 42ca6ca

2 files changed

Lines changed: 8 additions & 36 deletions

File tree

src/finn/transformation/qonnx/qonnx_activation_handlers.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,10 +422,6 @@ def _calculate_thresholds(self):
422422
), """Quant node cannot be converted to MultiThreshold because only
423423
per tensor or per channel quantization supported."""
424424

425-
final_shape = (num_output_channels, num_thresholds)
426-
if thresholds.shape != final_shape:
427-
thresholds = np.broadcast_to(thresholds, final_shape)
428-
429425
return thresholds
430426

431427
def _calculate_act_scale(self):
@@ -589,10 +585,6 @@ def _calculate_thresholds(self):
589585
), """Quant node cannot be converted to MultiThreshold because only
590586
per tensor or per channel quantization supported."""
591587

592-
final_shape = (num_output_channels, num_thresholds)
593-
if thresholds.shape != final_shape:
594-
thresholds = np.broadcast_to(thresholds, final_shape)
595-
596588
return thresholds
597589

598590
def _calculate_act_scale(self):

src/finn/transformation/streamline/absorb.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@
3434
from onnx import NodeProto # noqa
3535
from onnx import helper as oh
3636
from qonnx.core.datatype import DataType
37-
38-
# QONNX wrapper of ONNX model graphs
39-
from qonnx.core.modelwrapper import ModelWrapper
4037
from qonnx.custom_op.registry import getCustomOp
4138
from qonnx.transformation.base import Transformation
4239
from qonnx.transformation.infer_datatypes import InferDataTypes
@@ -106,19 +103,6 @@ def apply(self, model):
106103
return (model, graph_modified)
107104

108105

109-
# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
110-
# initializers. Keeps order of inputs in each category.
111-
def group_inputs_by_category(node: NodeProto, model: ModelWrapper): # noqa
112-
# First select all dynamic inputs, which are those without initializer
113-
# tensor
114-
dynamics = [i for i in node.input if model.get_initializer(i) is None]
115-
# Select all input which are initializers, which, by exclusion, are all
116-
# those not among the dynamic inputs
117-
initializers = [i for i in node.input if i not in dynamics]
118-
# Return lists of dynamic anc initializer inputs
119-
return dynamics, initializers
120-
121-
122106
class AbsorbAddIntoMultiThreshold(Transformation):
123107
"""Absorb preceding Add ops into MultiThreshold by updating the threshold
124108
values. Only scalar/1D add vectors can be absorbed."""
@@ -132,17 +116,13 @@ def apply(self, model):
132116
if n.op_type == "Add" and not model.is_fork_node(n) and not model.is_join_node(n):
133117
consumer = model.find_consumer(n.output[0])
134118
if consumer is not None and consumer.op_type == "MultiThreshold":
135-
# As Add is not a join node, there must be one initializer
136-
# and one dynamic input. We do not know their order, but
137-
# can group them accordingly to extract the tensor names
138-
(start,), (add_weight,) = group_inputs_by_category(n, model)
139-
threshold = consumer.input[1]
140-
A = model.get_initializer(add_weight)
141-
T = model.get_initializer(threshold)
142-
# Test for the thresholds actually being initializers
143-
# Note: No need to validate the add_weights anymore, this
144-
# is already handled by the grouping and is_join_node test.
119+
add_weight_name = n.input[1]
120+
threshold_name = consumer.input[1]
121+
A = model.get_initializer(add_weight_name)
122+
T = model.get_initializer(threshold_name)
123+
assert A is not None, "Initializer for add weights is not set."
145124
assert T is not None, "Initializer for thresholds is not set."
125+
start_name = n.input[0]
146126
# we can only absorb 0d or 1d adds
147127
is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
148128
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
@@ -151,9 +131,9 @@ def apply(self, model):
151131
Tnew = T - A.reshape(-1, 1)
152132
# Tnew = T - A.reshape(-1, T.shape[1])
153133
# compute new thresholds and set initializer
154-
model.set_initializer(threshold, Tnew)
134+
model.set_initializer(threshold_name, Tnew)
155135
# wire add input directly to MultiThreshold
156-
consumer.input[0] = start
136+
consumer.input[0] = start_name
157137
# remove the add node
158138
graph.node.remove(n)
159139
graph_modified = True

0 commit comments

Comments
 (0)