@@ -402,12 +402,24 @@ def _calculate_thresholds(self):
402402 else :
403403 thresholds [c ][t ] = step / selu_scale
404404
405- # First try to consider the tensor layout of the input for determining
406- # the number of output channels
405+ # Get the shape of the input (should also be the output) tensor
406+ # Note: Querying the input is more safe as we do not want to
407+ # propagate shapes backwards by accident.
408+ shape = self ._model .get_tensor_shape (self ._q_node .input [0 ]) # noqa
409+ # First try to consider the tensor layout of the input for
410+ # determining the number of output channels
407411 layout = self ._model .get_tensor_layout (self ._q_node .input [0 ])
408- # If there is a layout annotation, use this to determine the index of
409- # the channel dimension
410- if layout is not None and "C" in layout :
412+ # If there is no layout annotation, guess based on rank of the
413+ # tensor
414+ # TODO: No support for Rank >= 5
415+ if layout is None and len (shape ) < 5 :
416+ # Maps tensor rank to layout annotation
417+ rank_to_layout = {0 : None , 1 : "C" , 2 : "NC" , 3 : "NWC" , 4 : "NCHW" }
418+ # Lookup the layout required by this input shape
419+ layout = rank_to_layout [len (shape )]
420+ # If there is a layout annotation, use this to determine the index
421+ # of the channel dimension
422+ if layout is not None and "C" in layout : # noqa: Duplicate
411423 # Lookup the index in list
412424 cdim = layout .index ("C" )
413425 # If no layout has been annotated or there is no channel dimension, fall
@@ -570,12 +582,24 @@ def _calculate_thresholds(self):
570582 for t in range (num_thresholds ):
571583 thresholds [c ][t ] = min_threshold [c ] + step [c ] * t
572584
585+ # Get the shape of the input (should also be the output) tensor
586+ # Note: Querying the input is more safe as we do not want to
587+ # propagate shapes backwards by accident.
588+ shape = self ._model .get_tensor_shape (self ._q_node .input [0 ])
573589 # First try to consider the tensor layout of the input for
574590 # determining the number of output channels
575- layout = self ._model .get_tensor_layout (self ._q_node .input [0 ])
591+ layout = self ._model .get_tensor_layout (self ._q_node .input [0 ]) # noqa
592+ # If there is no layout annotation, guess based on rank of the
593+ # tensor
594+ # TODO: No support for Rank >= 5
595+ if layout is None and len (shape ) < 5 :
596+ # Maps tensor rank to layout annotation
597+ rank_to_layout = {0 : None , 1 : "C" , 2 : "NC" , 3 : "NWC" , 4 : "NCHW" }
598+ # Lookup the layout required by this input shape
599+ layout = rank_to_layout [len (shape )]
576600 # If there is a layout annotation, use this to determine the index
577601 # of the channel dimension
578- if layout is not None and "C" in layout :
602+ if layout is not None and "C" in layout : # noqa: Duplicate
579603 # Lookup the index in list
580604 cdim = layout .index ("C" )
581605 # If no layout has been annotated or there is no channel dimension,
0 commit comments