 from qonnx.custom_op.base import CustomOp


-def multithreshold(v, thresholds, out_scale=None, out_bias=None):
+def multithreshold(v, thresholds, out_scale=None, out_bias=None, channels_last=False):
     """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
     thresholding maps any real number x to an integer in the interval [0, n],
     where the returned integer is the number of thresholds x is greater than
     or equal to.

     The output tensor will be scaled by out_scale and biased by out_bias."""
-    # the inputs are expected to be in the shape (N,C,H,W) or (N, C)
-    # the MultiThreshold node supports a data_layout attribute that can be set
-    # to 'NHWC' to support (N,H,W,C) data layout mode for in-out as well
-    # N : Batch size
-    # C : Number of channels
-    # H : Heigth of the input images
-    # W : Width of the input images
-    #
-    # the thresholds are expected to be in the shape (C, B)
+    # if channels_last=False:
+    #   the inputs are expected to be in the shape (N,C,_) where:
+    #   C is the number of channels
+    #   _ represents any (including zero) number of spatial dims
+    # if channels_last=True, expected input shape is (N,_,C)
+    # the thresholds are expected to be in the shape (C, B) where
     # C : Number of channels (must be the same value as C in input tensor
     #     or 1 if all channels use the same threshold value)
     # B : Desired activation steps => i.e. for 4-bit activation,
-    #     B=7 (2^(n)-1 and n=4)
+    #     B=7 (2^(n)-1 and n=4), but can also be fewer
     # the output tensor will be scaled by out_scale and biased by out_bias
-    # assert threshold shape
+    # assert threshold shape - threshold channels must be either equal
+    # to input channels, or be a single global scalar
     is_global_threshold = thresholds.shape[0] == 1
+    channel_axis = -1 if channels_last else 1
     assert (
-        v.shape[1] == thresholds.shape[0]
+        v.shape[channel_axis] == thresholds.shape[0]
     ) or is_global_threshold, """"Threshold
     shape incorrect"""
-    # save the required shape sizes for the loops (N, C and B)
-    num_batch = v.shape[0]
-    num_channel = v.shape[1]
-    num_act = thresholds.shape[1]
-    # reshape inputs to enable channel-wise reading
-    vr = v.reshape((v.shape[0], v.shape[1], -1))
-    # initiate output tensor
-    ret = np.zeros_like(vr)
-    # iterate over thresholds channel-wise
-    for t in range(num_channel):
-        channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
-        # iterate over batches
-        for b in range(num_batch):
-            # iterate over the different thresholds for one channel
-            for a in range(num_act):
-                ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
+    if not channels_last:
+        # starting assumption: input tensor is in NC_ layout
+        # (where _ can be any number of spatial dims)
+        # get the input tensor into right shape to use numpy broadcasting
+        # move the channels axis to the last position
+        # NC_ -> N_C
+        vm = np.moveaxis(v, source=1, destination=-1)
+    else:
+        vm = v
+    # add a dummy dimension at the end of the input tensor
+    # (for broadcasting against the thresholds)
+    # N_C -> N_C1
+    vm = np.expand_dims(vm, axis=-1)
+    # now perform the comparison against thresholds
+    # (N_C1 >= CT) -> N_CT
+    cmp = vm >= thresholds
+    # replace last axis by count of nonzero values (True)
+    # N_CT -> N_C
+    # note the .astype cast to ensure type remains the same
+    # TODO enforce ints instead?
+    ret = np.count_nonzero(cmp, axis=-1).astype(v.dtype)
+    if not channels_last:
+        # move the channels axis back to index 1
+        ret = np.moveaxis(ret, source=-1, destination=1)
+    assert ret.shape == v.shape, "Shape changed during thresholding!"

     if out_scale is None:
         out_scale = 1.0
     if out_bias is None:
         out_bias = 0.0
-    return out_scale * ret.reshape(v.shape) + out_bias
+    return out_scale * ret + out_bias


 class MultiThreshold(CustomOp):
@@ -92,7 +100,7 @@ def get_nodeattr_types(self):
92100 "out_dtype" : ("s" , True , "" ),
93101 "out_scale" : ("f" , False , 1.0 ),
94102 "out_bias" : ("f" , False , 0.0 ),
95- "data_layout" : ("s" , False , "NCHW" , { "NCHW" , "NHWC" } ),
103+ "data_layout" : ("s" , False , "NCHW" ),
96104 }
97105
98106 def make_shape_compatible_op (self , model ):
@@ -124,51 +132,11 @@ def execute_node(self, context, graph):
         out_bias = self.get_nodeattr("out_bias")
         # transpose input if NHWC data layout is chosen
         data_layout = self.get_nodeattr("data_layout")
-        if data_layout == "NHWC":
-            if v.ndim == 4:
-                # NHWC -> NCHW
-                v = np.transpose(v, (0, 3, 1, 2))
-            elif v.ndim == 2:
-                # no HW dimension means NHWC and NCHW layouts are equivalent
-                pass
-            else:
-                raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")
-
-        # Remember whether the shape has been modified to handle 1d or 3d data
-        # layouts
-        orig_shape = None
-        # If the input tensor has dimensions not covered by the NC or NCWH data
-        # layouts, the shape needs to be adapted such that it can be handled by
-        # multithreshold.
-        # TODO: Seems like a rather sketchy solution to support arbitrary data
-        # layouts. This does not even validate the assumption of channel last
-        # layout.
-        if v.ndim not in {2, 4}:
-            # Remember the original shape to be restored later
-            orig_shape = v.shape
-            # Assume last dimension to be the channel dimension C and reshape
-            # into NC layout which is supported by multithreshold
-            v = v.reshape((-1, v.shape[-1]))
-
+        channels_last = True if data_layout[-1] == "C" else False
         # calculate output
-        output = multithreshold(v, thresholds, out_scale, out_bias)
-        # setting context according to output
-        if data_layout == "NHWC":
-            if output.ndim == 4:
-                # NCHW -> NHWC
-                output = np.transpose(output, (0, 2, 3, 1))
-            elif output.ndim == 2:
-                # no HW dimension means NHWC and NCHW layouts are equivalent
-                pass
-            else:
-                raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")
-
-        # If the shape has been modified to support arbitrary layouts, restore
-        # the original shape
-        # TODO: Part of the rather sketchy solution above.
-        if orig_shape is not None:
-            output = output.reshape(orig_shape)
-
+        orig_shape = v.shape
+        output = multithreshold(v, thresholds, out_scale, out_bias, channels_last)
+        assert output.shape == orig_shape, "Shape changed during thresholding!"
         context[node.output[0]] = output

     def verify_node(self):
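Aside (not part of the commit): the rewrite replaces the old triple Python loop with a single NumPy broadcast-and-count. A minimal standalone sketch of that idea, with a hypothetical helper name and made-up example values, for a channels-first input of shape (N, C, ...) and thresholds of shape (C, B):

import numpy as np

def multithreshold_sketch(v, thresholds):
    # hypothetical helper, illustrating only the broadcasting trick
    vm = np.moveaxis(v, 1, -1)                          # NC_ -> N_C
    vm = np.expand_dims(vm, axis=-1)                    # N_C -> N_C1
    cnt = np.count_nonzero(vm >= thresholds, axis=-1)   # count thresholds met per element
    return np.moveaxis(cnt, -1, 1).astype(v.dtype)      # N_C -> NC_

v = np.array([[[0.1, 0.6], [1.2, -0.3]]])        # shape (1, 2, 2): N=1, C=2, one spatial dim
thresholds = np.array([[0.0, 0.5], [0.0, 1.0]])  # shape (2, 2): per-channel thresholds, B=2
print(multithreshold_sketch(v, thresholds))
# [[[1. 2.]
#   [2. 0.]]]  -- each entry is the number of thresholds the input value meets or exceeds

Broadcasting (N, _, C, 1) against (C, B) aligns the channel axes automatically, which is why both a per-channel threshold table and a single global table (C=1) work without any explicit loop.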