Skip to content

Commit 0ed26fc

Browse files
author
Yaman Umuroglu
committed
[MultiThreshold] overhaul exec by using only npy vectorized ops
1 parent 7f331f2 commit 0ed26fc

1 file changed

Lines changed: 27 additions & 28 deletions

File tree

src/qonnx/custom_op/general/multithreshold.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,48 +40,47 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
4040
or equal to.
4141
4242
The output tensor will be scaled by out_scale and biased by out_bias."""
43-
# the inputs are expected to be in the shape (N,C,H,W) or (N, C)
44-
# the MultiThreshold node supports a data_layout attribute that can be set
45-
# to 'NHWC' to support (N,H,W,C) data layout mode for in-out as well
46-
# N : Batch size
47-
# C : Number of channels
48-
# H : Height of the input images
49-
# W : Width of the input images
50-
#
51-
# the thresholds are expected to be in the shape (C, B)
43+
# the inputs are expected to be in the shape (N,C,_) where:
44+
# C is the number of channels
45+
# _ represents any (including zero) number of spatial dims
46+
# the thresholds are expected to be in the shape (C, B) where
5247
# C : Number of channels (must be the same value as C in input tensor
5348
# or 1 if all channels use the same threshold value)
5449
# B : Desired activation steps => i.e. for 4-bit activation,
55-
# B=7 (2^(n)-1 and n=4)
50+
# B=7 (2^(n)-1 and n=4), but can also be fewer
5651
# the output tensor will be scaled by out_scale and biased by out_bias
57-
# assert threshold shape
52+
# assert threshold shape - threshold channels must be either equal
53+
# to input channels, or be a single global scalar
5854
is_global_threshold = thresholds.shape[0] == 1
5955
assert (
6056
v.shape[1] == thresholds.shape[0]
6157
) or is_global_threshold, """"Threshold
6258
shape incorrect"""
63-
# save the required shape sizes for the loops (N, C and B)
64-
num_batch = v.shape[0]
65-
num_channel = v.shape[1]
66-
num_act = thresholds.shape[1]
67-
# reshape inputs to enable channel-wise reading
68-
vr = v.reshape((v.shape[0], v.shape[1], -1))
69-
# initiate output tensor
70-
ret = np.zeros_like(vr)
71-
# iterate over thresholds channel-wise
72-
for t in range(num_channel):
73-
channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
74-
# iterate over batches
75-
for b in range(num_batch):
76-
# iterate over the different thresholds for one channel
77-
for a in range(num_act):
78-
ret[b][t] += (vr[b][t] >= channel_thresh[a]).astype(int)
59+
# starting assumption: input tensor is in NC_ layout
60+
# (where _ can be any number of spatial dims)
61+
# get the input tensor into right shape to use numpy broadcasting
62+
# move the channels axis to the last position
63+
# NC_ -> N_C
64+
vm = np.moveaxis(v, source=1, destination=-1)
65+
# add a dummy dimension at the end of the input tensor
66+
# (for broadcasting against the thresholds)
67+
# N_C -> N_C1
68+
vm = np.expand_dims(vm, axis=-1)
69+
# now perform the comparison against thresholds
70+
# (N_C1 >= CT) -> N_CT
71+
cmp = vm >= thresholds
72+
# replace last axis by count of nonzero values (True)
73+
# N_CT -> N_C
74+
ret = np.count_nonzero(cmp, axis=-1)
75+
# finally, move the channels axis back to index 1
76+
ret = np.moveaxis(ret, source=-1, destination=1)
77+
assert ret.shape == v.shape, "Shape changed during thresholding!"
7978

8079
if out_scale is None:
8180
out_scale = 1.0
8281
if out_bias is None:
8382
out_bias = 0.0
84-
return out_scale * ret.reshape(v.shape) + out_bias
83+
return out_scale * ret + out_bias
8584

8685

8786
class MultiThreshold(CustomOp):

0 commit comments

Comments
 (0)