Skip to content

Commit c2905f7

Browse files
committed
Make activation handler guess the layout based on tensor rank if missing
1 parent a8bcfcb commit c2905f7

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/finn/transformation/qonnx/qonnx_activation_handlers.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -402,12 +402,24 @@ def _calculate_thresholds(self):
402402
else:
403403
thresholds[c][t] = step / selu_scale
404404

405-
# First try to consider the tensor layout of the input for determining
406-
# the number of output channels
405+
# Get the shape of the input (should also be the output) tensor
406+
# Note: Querying the input is safer, as we do not want to
407+
# propagate shapes backwards by accident.
408+
shape = self._model.get_tensor_shape(self._q_node.input[0]) # noqa
409+
# First try to consider the tensor layout of the input for
410+
# determining the number of output channels
407411
layout = self._model.get_tensor_layout(self._q_node.input[0])
408-
# If there is a layout annotation, use this to determine the index of
409-
# the channel dimension
410-
if layout is not None and "C" in layout:
412+
# If there is no layout annotation, guess based on rank of the
413+
# tensor
414+
# TODO: No support for Rank >= 5
415+
if layout is None and len(shape) < 5:
416+
# Maps tensor rank to layout annotation
417+
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
418+
# Lookup the layout required by this input shape
419+
layout = rank_to_layout[len(shape)]
420+
# If there is a layout annotation, use this to determine the index
421+
# of the channel dimension
422+
if layout is not None and "C" in layout: # noqa: Duplicate
411423
# Lookup the index in list
412424
cdim = layout.index("C")
413425
# If no layout has been annotated or there is no channel dimension, fall
@@ -570,12 +582,24 @@ def _calculate_thresholds(self):
570582
for t in range(num_thresholds):
571583
thresholds[c][t] = min_threshold[c] + step[c] * t
572584

585+
# Get the shape of the input (should also be the output) tensor
586+
# Note: Querying the input is safer, as we do not want to
587+
# propagate shapes backwards by accident.
588+
shape = self._model.get_tensor_shape(self._q_node.input[0])
573589
# First try to consider the tensor layout of the input for
574590
# determining the number of output channels
575-
layout = self._model.get_tensor_layout(self._q_node.input[0])
591+
layout = self._model.get_tensor_layout(self._q_node.input[0]) # noqa
592+
# If there is no layout annotation, guess based on rank of the
593+
# tensor
594+
# TODO: No support for Rank >= 5
595+
if layout is None and len(shape) < 5:
596+
# Maps tensor rank to layout annotation
597+
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
598+
# Lookup the layout required by this input shape
599+
layout = rank_to_layout[len(shape)]
576600
# If there is a layout annotation, use this to determine the index
577601
# of the channel dimension
578-
if layout is not None and "C" in layout:
602+
if layout is not None and "C" in layout: # noqa: Duplicate
579603
# Lookup the index in list
580604
cdim = layout.index("C")
581605
# If no layout has been annotated or there is no channel dimension,

0 commit comments

Comments
 (0)