Skip to content

Commit eebe215

Browse files
author
Yaman Umuroglu
committed
[MultiThreshold] preserve output dtype in new implementation
1 parent a142fc5 commit eebe215

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/qonnx/custom_op/general/multithreshold.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None, channels_last=F
7777
cmp = vm >= thresholds
7878
# replace last axis by count of nonzero values (True)
7979
# N_CT -> N_C
80-
ret = np.count_nonzero(cmp, axis=-1)
80+
# note the .astype cast to ensure type remains the same
81+
# TODO enforce ints instead?
82+
ret = np.count_nonzero(cmp, axis=-1).astype(v.dtype)
8183
if not channels_last:
8284
# move the channels axis back to index 1
8385
ret = np.moveaxis(ret, source=-1, destination=1)

0 commit comments

Comments
 (0)