Skip to content

Commit 920e253

Browse files
author
Yaman Umuroglu
committed
[Test] add parametrized test for MultiThreshold
1 parent 0ed26fc commit 920e253

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

tests/custom_op/test_multithreshold.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
import pytest
30+
2931
import numpy as np
3032
import time
3133

@@ -296,6 +298,47 @@ def test_multithreshold():
296298
assert (results_scaled == outputs_scaled).all()
297299

298300

301+
@pytest.mark.parametrize("input_rank", [2, 3, 4])
302+
@pytest.mark.parametrize("num_channels", [1, 3])
303+
@pytest.mark.parametrize("num_steps", [1, 4, 7])
304+
@pytest.mark.parametrize("threshold_granularity", ["per_channel", "global"])
305+
@pytest.mark.parametrize("use_scale_bias", [False, True])
306+
def test_multithreshold_parametrized(input_rank, num_channels, num_steps, threshold_granularity, use_scale_bias):
307+
np.random.seed(0)
308+
N = 2
309+
310+
# Determine shape
311+
shape = [N, num_channels]
312+
spatial = []
313+
if input_rank == 3:
314+
spatial = [5]
315+
elif input_rank == 4:
316+
spatial = [3, 3]
317+
shape += spatial
318+
319+
# Generate random input data
320+
v = np.random.uniform(-3, 5, size=shape).astype(np.float32)
321+
322+
# Generate thresholds
323+
if threshold_granularity == "per_channel":
324+
thresholds = np.sort(np.random.uniform(-2, 4, size=(num_channels, num_steps)).astype(np.float32), axis=1)
325+
else: # global
326+
thresholds = np.sort(np.random.uniform(-2, 4, size=(1, num_steps)).astype(np.float32), axis=1)
327+
328+
# Generate scale and bias
329+
if use_scale_bias:
330+
out_scale = np.random.uniform(0.5, 2.0)
331+
out_bias = np.random.uniform(-1.0, 1.0)
332+
else:
333+
out_scale = None
334+
out_bias = None
335+
336+
ref = multithreshold_elementwise(v, thresholds, out_scale, out_bias)
337+
result = multithreshold(v, thresholds, out_scale, out_bias)
338+
339+
np.testing.assert_allclose(result, ref, rtol=1e-6, atol=1e-6)
340+
341+
299342
def multithreshold_performance():
300343
# performance and random test
301344
np.random.seed(0)

0 commit comments

Comments
 (0)