|
26 | 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
28 | 28 |
|
| 29 | +import pytest |
| 30 | + |
29 | 31 | import numpy as np |
30 | 32 | import time |
31 | 33 |
|
@@ -296,6 +298,47 @@ def test_multithreshold(): |
296 | 298 | assert (results_scaled == outputs_scaled).all() |
297 | 299 |
|
298 | 300 |
|
| 301 | +@pytest.mark.parametrize("input_rank", [2, 3, 4]) |
| 302 | +@pytest.mark.parametrize("num_channels", [1, 3]) |
| 303 | +@pytest.mark.parametrize("num_steps", [1, 4, 7]) |
| 304 | +@pytest.mark.parametrize("threshold_granularity", ["per_channel", "global"]) |
| 305 | +@pytest.mark.parametrize("use_scale_bias", [False, True]) |
| 306 | +def test_multithreshold_parametrized(input_rank, num_channels, num_steps, threshold_granularity, use_scale_bias): |
| 307 | + np.random.seed(0) |
| 308 | + N = 2 |
| 309 | + |
| 310 | + # Determine shape |
| 311 | + shape = [N, num_channels] |
| 312 | + spatial = [] |
| 313 | + if input_rank == 3: |
| 314 | + spatial = [5] |
| 315 | + elif input_rank == 4: |
| 316 | + spatial = [3, 3] |
| 317 | + shape += spatial |
| 318 | + |
| 319 | + # Generate random input data |
| 320 | + v = np.random.uniform(-3, 5, size=shape).astype(np.float32) |
| 321 | + |
| 322 | + # Generate thresholds |
| 323 | + if threshold_granularity == "per_channel": |
| 324 | + thresholds = np.sort(np.random.uniform(-2, 4, size=(num_channels, num_steps)).astype(np.float32), axis=1) |
| 325 | + else: # global |
| 326 | + thresholds = np.sort(np.random.uniform(-2, 4, size=(1, num_steps)).astype(np.float32), axis=1) |
| 327 | + |
| 328 | + # Generate scale and bias |
| 329 | + if use_scale_bias: |
| 330 | + out_scale = np.random.uniform(0.5, 2.0) |
| 331 | + out_bias = np.random.uniform(-1.0, 1.0) |
| 332 | + else: |
| 333 | + out_scale = None |
| 334 | + out_bias = None |
| 335 | + |
| 336 | + ref = multithreshold_elementwise(v, thresholds, out_scale, out_bias) |
| 337 | + result = multithreshold(v, thresholds, out_scale, out_bias) |
| 338 | + |
| 339 | + np.testing.assert_allclose(result, ref, rtol=1e-6, atol=1e-6) |
| 340 | + |
| 341 | + |
299 | 342 | def multithreshold_performance(): |
300 | 343 | # performance and random test |
301 | 344 | np.random.seed(0) |
|
0 commit comments