Skip to content

Commit 82f38a2

Browse files
author
Yaman Umuroglu
authored
Merge pull request #201 from fastmachinelearning/feature/threshold_revamp
Revamp MultiThreshold
2 parents 33f2670 + b0265fb commit 82f38a2

3 files changed

Lines changed: 292 additions & 268 deletions

File tree

src/qonnx/custom_op/general/multithreshold.py

Lines changed: 43 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -33,55 +33,63 @@
3333
from qonnx.custom_op.base import CustomOp
3434

3535

36-
def multithreshold(v, thresholds, out_scale=None, out_bias=None, channels_last=False):
    """Apply successive thresholding to every element of ``v``.

    Given a set of threshold values t={t_0, t_1 ... t_n}, successive
    thresholding maps any real number x to an integer in the interval [0, n],
    where the returned integer is the number of thresholds x is greater than
    or equal to. The output tensor is then scaled by ``out_scale`` and biased
    by ``out_bias``.

    Parameters
    ----------
    v : np.ndarray
        Input tensor. If ``channels_last=False`` the expected shape is
        (N, C, _) where C is the number of channels and _ is any (including
        zero) number of spatial dims. If ``channels_last=True`` the expected
        shape is (N, _, C).
    thresholds : np.ndarray
        Threshold tensor of shape (C, B), where C must equal the input's
        channel count, or be 1 to use a single global threshold set for all
        channels. B is the number of activation steps, e.g. B=7 (2^n - 1 with
        n=4) for a 4-bit activation, but fewer steps are also allowed.
    out_scale : float, optional
        Multiplicative scale applied to the thresholding result; defaults
        to 1.0 when None.
    out_bias : float, optional
        Additive bias applied after scaling; defaults to 0.0 when None.
    channels_last : bool, optional
        Whether the channel axis of ``v`` is the last axis instead of axis 1.

    Returns
    -------
    np.ndarray
        Tensor with the same shape as ``v`` containing
        ``out_scale * count + out_bias``.
    """
    # Threshold channels must either equal the input channels or be a
    # single global scalar set (C == 1).
    is_global_threshold = thresholds.shape[0] == 1
    channel_axis = -1 if channels_last else 1
    # NOTE: original had a stray quote in this message ('""""Threshold...');
    # fixed to a clean single-line message.
    assert (
        v.shape[channel_axis] == thresholds.shape[0]
    ) or is_global_threshold, "Threshold shape incorrect"
    if not channels_last:
        # starting assumption: input tensor is in NC_ layout
        # (where _ can be any number of spatial dims)
        # move the channels axis to the last position so numpy broadcasting
        # against the (C, B) thresholds lines up: NC_ -> N_C
        vm = np.moveaxis(v, source=1, destination=-1)
    else:
        vm = v
    # add a dummy dimension at the end of the input tensor
    # (for broadcasting against the thresholds): N_C -> N_C1
    vm = np.expand_dims(vm, axis=-1)
    # now perform the comparison against thresholds: (N_C1 >= CT) -> N_CT
    cmp = vm >= thresholds
    # replace last axis by count of nonzero values (True): N_CT -> N_C
    # note the .astype cast to ensure type remains the same
    # TODO enforce ints instead?
    ret = np.count_nonzero(cmp, axis=-1).astype(v.dtype)
    if not channels_last:
        # move the channels axis back to index 1: N_C -> NC_
        ret = np.moveaxis(ret, source=-1, destination=1)
    assert ret.shape == v.shape, "Shape changed during thresholding!"

    if out_scale is None:
        out_scale = 1.0
    if out_bias is None:
        out_bias = 0.0
    return out_scale * ret + out_bias
8593

8694

8795
class MultiThreshold(CustomOp):
@@ -92,7 +100,7 @@ def get_nodeattr_types(self):
92100
"out_dtype": ("s", True, ""),
93101
"out_scale": ("f", False, 1.0),
94102
"out_bias": ("f", False, 0.0),
95-
"data_layout": ("s", False, "NCHW", {"NCHW", "NHWC"}),
103+
"data_layout": ("s", False, "NCHW"),
96104
}
97105

98106
def make_shape_compatible_op(self, model):
@@ -124,51 +132,11 @@ def execute_node(self, context, graph):
124132
out_bias = self.get_nodeattr("out_bias")
125133
# transpose input if NHWC data layout is chosen
126134
data_layout = self.get_nodeattr("data_layout")
127-
if data_layout == "NHWC":
128-
if v.ndim == 4:
129-
# NHWC -> NCHW
130-
v = np.transpose(v, (0, 3, 1, 2))
131-
elif v.ndim == 2:
132-
# no HW dimension means NHWC and NCHW layouts are equivalent
133-
pass
134-
else:
135-
raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")
136-
137-
# Remember whether the shape has been modified to handle 1d or 3d data
138-
# layouts
139-
orig_shape = None
140-
# If the input tensor has dimensions not covered by the NC or NCWH data
141-
# layouts, the shape needs to be adapted such that it can be handled by
142-
# multithreshold.
143-
# TODO: Seems like a rather sketchy solution to support arbitrary data
144-
# layouts. This does not even validate the assumption of channel last
145-
# layout.
146-
if v.ndim not in {2, 4}:
147-
# Remember the original shape to be restored later
148-
orig_shape = v.shape
149-
# Assume last dimension to be the channel dimension C and reshape
150-
# into NC layout which is supported by multithreshold
151-
v = v.reshape((-1, v.shape[-1]))
152-
135+
channels_last = True if data_layout[-1] == "C" else False
153136
# calculate output
154-
output = multithreshold(v, thresholds, out_scale, out_bias)
155-
# setting context according to output
156-
if data_layout == "NHWC":
157-
if output.ndim == 4:
158-
# NCHW -> NHWC
159-
output = np.transpose(output, (0, 2, 3, 1))
160-
elif output.ndim == 2:
161-
# no HW dimension means NHWC and NCHW layouts are equivalent
162-
pass
163-
else:
164-
raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")
165-
166-
# If the shape has been modified to support arbitrary layouts, restore
167-
# the original shape
168-
# TODO: Part of the rather sketchy solution above.
169-
if orig_shape is not None:
170-
output = output.reshape(orig_shape)
171-
137+
orig_shape = v.shape
138+
output = multithreshold(v, thresholds, out_scale, out_bias, channels_last)
139+
assert output.shape == orig_shape, "Shape changed during thresholding!"
172140
context[node.output[0]] = output
173141

174142
def verify_node(self):

tests/core/test_custom_onnx_exec.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ def test_execute_custom_node_multithreshold():
274274
assert (execution_context["out"] == outputs_nhwc).all()
275275
# check the set of allowed values
276276
op_inst = getCustomOp(node_def)
277-
assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC"}
277+
# TODO: Removed this check to generalize the supported data layouts, but do
278+
# we need some other check to verify the validity of data layouts?
279+
# assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC", "NC", "NWC", "NCW"}
278280
# exercise the allowed value checks
279281
# try to set attribute to non-allowed value, should raise an exception
280282
try:

0 commit comments

Comments
 (0)