@@ -40,48 +40,47 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
4040 or equal to.
4141
4242 The output tensor will be scaled by out_scale and biased by out_bias."""
43- # the inputs are expected to be in the shape (N,C,H,W) or (N, C)
44- # the MultiThreshold node supports a data_layout attribute that can be set
45- # to 'NHWC' to support (N,H,W,C) data layout mode for in-out as well
46- # N : Batch size
47- # C : Number of channels
48- # H : Heigth of the input images
49- # W : Width of the input images
50- #
51- # the thresholds are expected to be in the shape (C, B)
43+ # the inputs are expected to be in the shape (N,C,_) where:
44+ # C is the number of channels
45+ # _ represents any (including zero) number of spatial dims
46+ # the thresholds are expected to be in the shape (C, B) where
5247 # C : Number of channels (must be the same value as C in input tensor
5348 # or 1 if all channels use the same threshold value)
5449 # B : Desired activation steps => i.e. for 4-bit activation,
55- # B=7 (2^(n)-1 and n=4)
50+ # B=7 (2^(n)-1 and n=4), but can also be fewer
5651 # the output tensor will be scaled by out_scale and biased by out_bias
57- # assert threshold shape
52+ # assert threshold shape - threshold channels must be either equal
53+ # to input channels, or be a single global scalar
5854 is_global_threshold = thresholds .shape [0 ] == 1
5955 assert (
6056 v .shape [1 ] == thresholds .shape [0 ]
6157 ) or is_global_threshold , """"Threshold
6258 shape incorrect"""
63- # save the required shape sizes for the loops (N, C and B)
64- num_batch = v .shape [0 ]
65- num_channel = v .shape [1 ]
66- num_act = thresholds .shape [1 ]
67- # reshape inputs to enable channel-wise reading
68- vr = v .reshape ((v .shape [0 ], v .shape [1 ], - 1 ))
69- # initiate output tensor
70- ret = np .zeros_like (vr )
71- # iterate over thresholds channel-wise
72- for t in range (num_channel ):
73- channel_thresh = thresholds [0 ] if is_global_threshold else thresholds [t ]
74- # iterate over batches
75- for b in range (num_batch ):
76- # iterate over the different thresholds for one channel
77- for a in range (num_act ):
78- ret [b ][t ] += (vr [b ][t ] >= channel_thresh [a ]).astype (int )
59+ # starting assumption: input tensor is in NC_ layout
60+ # (where _ can be any number of spatial dims)
61+ # get the input tensor into right shape to use numpy broadcasting
62+ # move the channels axis to the last position
63+ # NC_ -> N_C
64+ vm = np .moveaxis (v , source = 1 , destination = - 1 )
65+ # add a dummy dimension at the end of the input tensor
66+ # (for broadcasting against the thresholds)
67+ # N_C -> N_C1
68+ vm = np .expand_dims (vm , axis = - 1 )
69+ # now perform the comparison against thresholds
70+ # (N_C1 >= CT) -> N_CT
71+ cmp = vm >= thresholds
72+ # replace last axis by count of nonzero values (True)
73+ # N_CT -> N_C
74+ ret = np .count_nonzero (cmp , axis = - 1 )
75+ # finally, move the channels axis back to index 1
76+ ret = np .moveaxis (ret , source = - 1 , destination = 1 )
77+ assert ret .shape == v .shape , "Shape changed during thresholding!"
7978
8079 if out_scale is None :
8180 out_scale = 1.0
8281 if out_bias is None :
8382 out_bias = 0.0
84- return out_scale * ret . reshape ( v . shape ) + out_bias
83+ return out_scale * ret + out_bias
8584
8685
8786class MultiThreshold (CustomOp ):
0 commit comments