3333from qonnx .custom_op .base import CustomOp
3434
3535
36- def multithreshold (v , thresholds , out_scale = None , out_bias = None ):
36+ def multithreshold (v , thresholds , out_scale = None , out_bias = None , channels_last = False ):
3737 """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
3838 thresholding maps any real number x to an integer in the interval [0, n],
3939 where the returned integer is the number of thresholds x is greater than
4040 or equal to.
4141
4242 The output tensor will be scaled by out_scale and biased by out_bias."""
43+ # if channels_last=False:
4344 # the inputs are expected to be in the shape (N,C,_) where:
4445 # C is the number of channels
4546 # _ represents any (including zero) number of spatial dims
47+ # if channels_last=True, expected input shape is (N,_,C)
4648 # the thresholds are expected to be in the shape (C, B) where
4749 # C : Number of channels (must be the same value as C in input tensor
4850 # or 1 if all channels use the same threshold value)
@@ -52,16 +54,20 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
5254 # assert threshold shape - threshold channels must be either equal
5355 # to input channels, or be a single global scalar
5456 is_global_threshold = thresholds .shape [0 ] == 1
57+ channel_axis = - 1 if channels_last else 1
5558 assert (
56- v .shape [1 ] == thresholds .shape [0 ]
59+ v .shape [channel_axis ] == thresholds .shape [0 ]
5760 ) or is_global_threshold , """"Threshold
5861 shape incorrect"""
59- # starting assumption: input tensor is in NC_ layout
60- # (where _ can be any number of spatial dims)
61- # get the input tensor into right shape to use numpy broadcasting
62- # move the channels axis to the last position
63- # NC_ -> N_C
64- vm = np .moveaxis (v , source = 1 , destination = - 1 )
62+ if not channels_last :
63+ # starting assumption: input tensor is in NC_ layout
64+ # (where _ can be any number of spatial dims)
65+ # get the input tensor into right shape to use numpy broadcasting
66+ # move the channels axis to the last position
67+ # NC_ -> N_C
68+ vm = np .moveaxis (v , source = 1 , destination = - 1 )
69+ else :
70+ vm = v
6571 # add a dummy dimension at the end of the input tensor
6672 # (for broadcasting against the thresholds)
6773 # N_C -> N_C1
@@ -72,8 +78,9 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
7278 # replace last axis by count of nonzero values (True)
7379 # N_CT -> N_C
7480 ret = np .count_nonzero (cmp , axis = - 1 )
75- # finally, move the channels axis back to index 1
76- ret = np .moveaxis (ret , source = - 1 , destination = 1 )
81+ if not channels_last :
82+ # move the channels axis back to index 1
83+ ret = np .moveaxis (ret , source = - 1 , destination = 1 )
7784 assert ret .shape == v .shape , "Shape changed during thresholding!"
7885
7986 if out_scale is None :
@@ -123,51 +130,11 @@ def execute_node(self, context, graph):
123130 out_bias = self .get_nodeattr ("out_bias" )
124131 # transpose input if NHWC data layout is chosen
125132 data_layout = self .get_nodeattr ("data_layout" )
126- if data_layout == "NHWC" :
127- if v .ndim == 4 :
128- # NHWC -> NCHW
129- v = np .transpose (v , (0 , 3 , 1 , 2 ))
130- elif v .ndim == 2 :
131- # no HW dimension means NHWC and NCHW layouts are equivalent
132- pass
133- else :
134- raise Exception ("Unknown data_layout and input ndim" " combination for MultiThreshold." )
135-
136- # Remember whether the shape has been modified to handle 1d or 3d data
137- # layouts
138- orig_shape = None
139- # If the input tensor has dimensions not covered by the NC or NCWH data
140- # layouts, the shape needs to be adapted such that it can be handled by
141- # multithreshold.
142- # TODO: Seems like a rather sketchy solution to support arbitrary data
143- # layouts. This does not even validate the assumption of channel last
144- # layout.
145- if v .ndim not in {2 , 4 }:
146- # Remember the original shape to be restored later
147- orig_shape = v .shape
148- # Assume last dimension to be the channel dimension C and reshape
149- # into NC layout which is supported by multithreshold
150- v = v .reshape ((- 1 , v .shape [- 1 ]))
151-
133+ channels_last = True if data_layout [- 1 ] == "C" else False
152134 # calculate output
153- output = multithreshold (v , thresholds , out_scale , out_bias )
154- # setting context according to output
155- if data_layout == "NHWC" :
156- if output .ndim == 4 :
157- # NCHW -> NHWC
158- output = np .transpose (output , (0 , 2 , 3 , 1 ))
159- elif output .ndim == 2 :
160- # no HW dimension means NHWC and NCHW layouts are equivalent
161- pass
162- else :
163- raise Exception ("Unknown data_layout and output ndim" " combination for MultiThreshold." )
164-
165- # If the shape has been modified to support arbitrary layouts, restore
166- # the original shape
167- # TODO: Part of the rather sketchy solution above.
168- if orig_shape is not None :
169- output = output .reshape (orig_shape )
170-
135+ orig_shape = v .shape
136+ output = multithreshold (v , thresholds , out_scale , out_bias , channels_last )
137+ assert output .shape == orig_shape , "Shape changed during thresholding!"
171138 context [node .output [0 ]] = output
172139
173140 def verify_node (self ):
0 commit comments