 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
 import numpy as np
+import warnings
 from abc import ABC, abstractmethod
 from onnx import TensorProto, helper
 from qonnx.core.modelwrapper import ModelWrapper
@@ -70,7 +70,7 @@ def _check_compatibility(self):
     @abstractmethod
     def _calculate_act_bias(self):
         """Calculate the activation bias,
-        which is introduced as an Add node behind the MultiTrheshold node.
+        which is introduced as an Add node behind the MultiThreshold node.
         """
         raise NotImplementedError()
 
@@ -82,7 +82,7 @@ def _calculate_thresholds(self):
     @abstractmethod
     def _calculate_act_scale(self):
         """Calculate the activation scale,
-        which is indroduced as a Mul node behind the Add node
+        which is introduced as a Mul node behind the Add node
         for the activation bias.
         """
         raise NotImplementedError()
@@ -139,6 +139,8 @@ def replace_quant_node(self):
         graph.value_info.append(thresh_tensor)
         model.set_initializer(thresh_tensor.name, thresholds)
 
+        data_layout = model.get_tensor_layout(n.input[0])
+
         # Insert MultiThreshold node
         outp_trans_node = helper.make_node(
             "MultiThreshold",
@@ -154,10 +156,15 @@ def replace_quant_node(self):
         mt_node = graph.node[running_node_index - 1]
         mt_inst = getCustomOp(mt_node)
 
+        # Inherit the data layout from the input tensor if available
+        if data_layout is not None:
+            # Convert list to string representation of the data layout
+            mt_inst.set_nodeattr("data_layout", "".join(data_layout))
+
         # Set scale and bias
         # If these values are scalar then they can be set as attributes
         # of the MultiThreshold node, if not they get inserted as adder and mul nodes
-        # behind the MultiTrheshold nodes.
+        # behind the MultiThreshold nodes.
         bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
         scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
         if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
@@ -355,7 +362,7 @@ def _calculate_thresholds(self):
         act_node = self._model.find_direct_predecessors(self._q_node)
         act_node = act_node[0]
         if act_node.op_type == "Relu":
-            # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
+            # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
             # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
             # onnx/finn/handler/act.py#L21
             num_distinct_values = 2 ** bit_width
@@ -395,8 +402,46 @@ def _calculate_thresholds(self):
                     else:
                         thresholds[c][t] = step / selu_scale
 
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is more safe as we do not want to
+        # propagate shapes backwards by accident.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])  # noqa
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.input[0])
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        # TODO: No support for Rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension, fall
+        # back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.input[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
         # ToDo: The index 1 needs to be changed to -1 for the channels last format
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
+        assert (
+            thresholds.shape[0] == 1 or thresholds.shape[
+                0] == num_output_channels
+        ), """Quant node cannot be converted to MultiThreshold because only
+        per tensor or per channel quantization supported."""
+
         final_shape = (num_output_channels, num_thresholds)
         if thresholds.shape != final_shape:
             thresholds = np.broadcast_to(thresholds, final_shape)
@@ -417,12 +462,12 @@ def _remove_activation_node(self, multi_threshold_node):
         act_node = self._model.find_direct_predecessors(self._q_node)
         if act_node is None:
             raise RuntimeError(
-                "For handling of Relu activations a predecesor to " "the Quant node must exist."
+                "For handling of Relu activations a predecessor to " "the Quant node must exist."
             )
         act_node = act_node[0]
         if act_node.op_type not in self.valid_predecessor_op_types():
             raise RuntimeError(
-                "The predecesor of the Quant node must be Relu or Selu for handling "
+                "The predecessor of the Quant node must be Relu or Selu for handling "
                 "of activations."
             )
 
@@ -509,7 +554,7 @@ def _calculate_thresholds(self):
         else:
             raise RuntimeError("Got an unexpected quantizer node type")
 
-        # Calculate thersholds, see: https://github.com/Xilinx/brevitas/
+        # Calculate thresholds, see: https://github.com/Xilinx/brevitas/
         # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
         # export/onnx/finn/handler/act.py#L76
         if bit_width == 1.0:
@@ -537,13 +582,49 @@ def _calculate_thresholds(self):
                 for t in range(num_thresholds):
                     thresholds[c][t] = min_threshold[c] + step[c] * t
 
-        # currently only per tensor or per channel quantization is supported
-        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
+        # Get the shape of the input (should also be the output) tensor
+        # Note: Querying the input is more safe as we do not want to
+        # propagate shapes backwards by accident.
+        shape = self._model.get_tensor_shape(self._q_node.input[0])
+        # First try to consider the tensor layout of the input for
+        # determining the number of output channels
+        layout = self._model.get_tensor_layout(self._q_node.input[0])  # noqa
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        # TODO: No support for Rank >= 5
+        if layout is None and len(shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            layout = rank_to_layout[len(shape)]
+        # If there is a layout annotation, use this to determine the index
+        # of the channel dimension
+        if layout is not None and "C" in layout:  # noqa: Duplicate
+            # Lookup the index in list
+            cdim = layout.index("C")
+        # If no layout has been annotated or there is no channel dimension,
+        # fall back to the previous default assumption
+        else:
+            # Assume the channels to be in axis 1
+            cdim = 1
+            # Issue a warning to the user, so they are aware of this
+            warnings.warn(
+                f"No layout annotations for {self._q_node.input[0]}:"
+                f" Assuming channel dimension at index {cdim}"
+            )
+
+        # ToDo: The index 1 needs to be changed to -1 for the channels last format
+        num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
+
         assert (
             thresholds.shape[0] == 1 or thresholds.shape[0] == num_output_channels
         ), """Quant node cannot be converted to MultiThreshold because only
         per tensor or per channel quantization supported."""
 
+        final_shape = (num_output_channels, num_thresholds)
+        if thresholds.shape != final_shape:
+            thresholds = np.broadcast_to(thresholds, final_shape)
+
         return thresholds
 
     def _calculate_act_scale(self):
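
As a usage sketch for context (an assumption on my part: this diff targets FINN's qonnx activation handlers, which are driven by the ConvertQONNXtoFINN transformation; the file names below are placeholders), converting a QONNX model so its Quant/BipolarQuant activations become MultiThreshold nodes, which with this change inherit the data layout annotated on their input tensor, could look like this:

from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

# Load a QONNX model containing Quant/BipolarQuant activation nodes
# (placeholder file name)
model = ModelWrapper("quant_act_model.onnx")
# Replace quantized activations with MultiThreshold (plus Add/Mul) nodes
model = model.transform(ConvertQONNXtoFINN())
model.save("multithreshold_model.onnx")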