
Commit a142fc5

Author: Yaman Umuroglu (committed)
[MultiThreshold] clean up layout handling in MT custom op
should now support arbitrary layouts, but remaining backwards compatible with the data_layout attribute may be a problem
1 parent 920e253 commit a142fc5

1 file changed: 21 additions & 54 deletions


src/qonnx/custom_op/general/multithreshold.py

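For orientation before the diff: a minimal sketch of what the new keyword argument enables, assuming illustrative tensor shapes and the default out_scale/out_bias (only the channels_last argument itself comes from this commit):

import numpy as np
from qonnx.custom_op.general.multithreshold import multithreshold

# channels-first input (NCHW): channels on axis 1, the previous default assumption
x_nchw = np.random.rand(1, 3, 4, 4)
thresholds = np.random.rand(3, 7)  # (C, B): per-channel threshold steps

y_nchw = multithreshold(x_nchw, thresholds)

# channels-last input (NHWC): channels on the last axis, now handled directly
x_nhwc = np.moveaxis(x_nchw, 1, -1)
y_nhwc = multithreshold(x_nhwc, thresholds, channels_last=True)

# the two results should agree up to the layout permutation
assert np.array_equal(np.moveaxis(y_nchw, 1, -1), y_nhwc)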
@@ -33,16 +33,18 @@
 from qonnx.custom_op.base import CustomOp
 
 
-def multithreshold(v, thresholds, out_scale=None, out_bias=None):
+def multithreshold(v, thresholds, out_scale=None, out_bias=None, channels_last=False):
     """Given a set of threshold values t={t_0, t_1 ... t_n} the successive
     thresholding maps any real number x to an integer in the interval [0, n],
     where the returned integer is the number of thresholds x is greater than
     or equal to.
 
     The output tensor will be scaled by out_scale and biased by out_bias."""
+    # if channels_last=False:
     # the inputs are expected to be in the shape (N,C,_) where:
     # C is the number of channels
     # _ represents any (including zero) number of spatial dims
+    # if channels_last=True, expected input shape is (N,_,C)
     # the thresholds are expected to be in the shape (C, B) where
     # C : Number of channels (must be the same value as C in input tensor
     # or 1 if all channels use the same threshold value)
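A small worked example of the semantics described in the docstring above (values and shapes are illustrative, not taken from the commit):

import numpy as np

v = np.array([[0.3, 2.5]])                # (N=1, C=2), no spatial dims
thresholds = np.array([[0.0, 1.0, 2.0],   # channel 0: B=3 threshold steps
                       [0.5, 1.5, 2.5]])  # channel 1
# channel 0: 0.3 >= {0.0} only                 -> 1
# channel 1: 2.5 >= {0.5, 1.5, 2.5}, i.e. all  -> 3
# so multithreshold(v, thresholds) returns [[1, 3]] (with default scale/bias)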
@@ -52,16 +54,20 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     # assert threshold shape - threshold channels must be either equal
     # to input channels, or be a single global scalar
     is_global_threshold = thresholds.shape[0] == 1
+    channel_axis = -1 if channels_last else 1
     assert (
-        v.shape[1] == thresholds.shape[0]
+        v.shape[channel_axis] == thresholds.shape[0]
     ) or is_global_threshold, """"Threshold
     shape incorrect"""
-    # starting assumption: input tensor is in NC_ layout
-    # (where _ can be any number of spatial dims)
-    # get the input tensor into right shape to use numpy broadcasting
-    # move the channels axis to the last position
-    # NC_ -> N_C
-    vm = np.moveaxis(v, source=1, destination=-1)
+    if not channels_last:
+        # starting assumption: input tensor is in NC_ layout
+        # (where _ can be any number of spatial dims)
+        # get the input tensor into right shape to use numpy broadcasting
+        # move the channels axis to the last position
+        # NC_ -> N_C
+        vm = np.moveaxis(v, source=1, destination=-1)
+    else:
+        vm = v
     # add a dummy dimension at the end of the input tensor
     # (for broadcasting against the thresholds)
     # N_C -> N_C1
@@ -72,8 +78,9 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     # replace last axis by count of nonzero values (True)
     # N_CT -> N_C
     ret = np.count_nonzero(cmp, axis=-1)
-    # finally, move the channels axis back to index 1
-    ret = np.moveaxis(ret, source=-1, destination=1)
+    if not channels_last:
+        # move the channels axis back to index 1
+        ret = np.moveaxis(ret, source=-1, destination=1)
     assert ret.shape == v.shape, "Shape changed during thresholding!"
 
     if out_scale is None:
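The two hunks above keep the numpy broadcasting trick at the core of the function; a standalone sketch of the same mechanism, reusing the numbers from the worked example earlier and assuming (as the docstring states) a greater-than-or-equal comparison:

import numpy as np

v = np.array([[0.3, 2.5]])                # already in N_C (channels-last) form
thresholds = np.array([[0.0, 1.0, 2.0],
                       [0.5, 1.5, 2.5]])  # (C, B)

# N_C -> N_C1: dummy axis so each value is compared against its B thresholds
cmp = v[..., np.newaxis] >= thresholds    # broadcasts to shape (N, C, B)
ret = np.count_nonzero(cmp, axis=-1)      # (N, C): thresholds crossed per value
print(ret)                                # [[1 3]]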
@@ -123,51 +130,11 @@ def execute_node(self, context, graph):
         out_bias = self.get_nodeattr("out_bias")
         # transpose input if NHWC data layout is chosen
         data_layout = self.get_nodeattr("data_layout")
-        if data_layout == "NHWC":
-            if v.ndim == 4:
-                # NHWC -> NCHW
-                v = np.transpose(v, (0, 3, 1, 2))
-            elif v.ndim == 2:
-                # no HW dimension means NHWC and NCHW layouts are equivalent
-                pass
-            else:
-                raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")
-
-        # Remember whether the shape has been modified to handle 1d or 3d data
-        # layouts
-        orig_shape = None
-        # If the input tensor has dimensions not covered by the NC or NCWH data
-        # layouts, the shape needs to be adapted such that it can be handled by
-        # multithreshold.
-        # TODO: Seems like a rather sketchy solution to support arbitrary data
-        # layouts. This does not even validate the assumption of channel last
-        # layout.
-        if v.ndim not in {2, 4}:
-            # Remember the original shape to be restored later
-            orig_shape = v.shape
-            # Assume last dimension to be the channel dimension C and reshape
-            # into NC layout which is supported by multithreshold
-            v = v.reshape((-1, v.shape[-1]))
-
+        channels_last = True if data_layout[-1] == "C" else False
         # calculate output
-        output = multithreshold(v, thresholds, out_scale, out_bias)
-        # setting context according to output
-        if data_layout == "NHWC":
-            if output.ndim == 4:
-                # NCHW -> NHWC
-                output = np.transpose(output, (0, 2, 3, 1))
-            elif output.ndim == 2:
-                # no HW dimension means NHWC and NCHW layouts are equivalent
-                pass
-            else:
-                raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")
-
-        # If the shape has been modified to support arbitrary layouts, restore
-        # the original shape
-        # TODO: Part of the rather sketchy solution above.
-        if orig_shape is not None:
-            output = output.reshape(orig_shape)
-
+        orig_shape = v.shape
+        output = multithreshold(v, thresholds, out_scale, out_bias, channels_last)
+        assert output.shape == orig_shape, "Shape changed during thresholding!"
         context[node.output[0]] = output
 
     def verify_node(self):
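A note on the reworked execute_node path: the data_layout attribute is now only inspected for its last character, so any layout string ending in "C" is treated as channels-last and everything else as channels-first. A quick illustration (the layout strings are examples; only the [-1] == "C" check itself comes from the diff):

for data_layout in ["NCHW", "NHWC", "NC", "NCW", "NWC"]:
    channels_last = data_layout[-1] == "C"
    print(data_layout, "->", "channels_last" if channels_last else "channels_first")
# NCHW -> channels_first
# NHWC -> channels_last
# NC   -> channels_last (harmless: axis 1 and axis -1 coincide for 2-D input)
# NCW  -> channels_first
# NWC  -> channels_last

This is presumably the backwards-compatibility concern flagged in the commit message: correctness now hinges on the data_layout attribute ending in the right letter rather than on explicit NHWC/NCHW handling.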
