Skip to content

Commit a424878

Browse files
committed
Adjust stateless tests
Signed-off-by: Michał Szołucha <[email protected]>
1 parent 1c14735 commit a424878

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

dali/test/python/checkpointing/test_dali_stateless_operators.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,10 @@ def test_resize_stateless(device):
213213
check_single_input(fn.resize, device, resize_x=50, resize_y=50)
214214

215215

216-
@params("cpu", "gpu")
217-
@stateless_signed_off("experimental.tensor_resize")
218-
def test_tensor_resize_stateless(device):
219-
check_single_input(fn.experimental.tensor_resize, device, axes=[0, 1], sizes=[40, 40])
216+
@cartesian_params(("cpu", "gpu"), (fn.experimental.tensor_resize, fn.tensor_resize))
217+
@stateless_signed_off("experimental.tensor_resize", "tensor_resize")
218+
def test_tensor_resize_stateless(device, tensor_resize_op):
219+
check_single_input(tensor_resize_op, device, axes=[0, 1], sizes=[40, 40])
220220

221221

222222
@params("cpu", "gpu")
@@ -334,10 +334,10 @@ def test_reductions_variance_stateless(device):
334334
check_single_input(lambda x, **kwargs: fn.reductions.variance(x, 5.0, **kwargs), device)
335335

336336

337-
@params("cpu", "gpu")
338-
@stateless_signed_off("experimental.equalize")
339-
def test_equalize_stateless(device):
340-
check_single_input(fn.experimental.equalize, device)
337+
@cartesian_params(("cpu", "gpu"), (fn.experimental.equalize, fn.equalize))
338+
@stateless_signed_off("experimental.equalize", "equalize")
339+
def test_equalize_stateless(device, equalize_op):
340+
check_single_input(equalize_op, device)
341341

342342

343343
@stateless_signed_off("transforms.crop")
@@ -468,11 +468,11 @@ def test_sphere_stateless(device):
468468
check_single_input(fn.sphere, device)
469469

470470

471-
@params("cpu", "gpu")
472-
@stateless_signed_off("experimental.filter")
473-
def test_filter_stateless(device):
471+
@cartesian_params(("cpu", "gpu"), (fn.experimental.filter, fn.filter))
472+
@stateless_signed_off("experimental.filter", "filter")
473+
def test_filter_stateless(device, filter_op):
474474
check_single_input(
475-
lambda x, **kwargs: fn.experimental.filter(x, np.full((3, 3), 1 / 9), **kwargs),
475+
lambda x, **kwargs: filter_op(x, np.full((3, 3), 1 / 9), **kwargs),
476476
device,
477477
)
478478

@@ -494,15 +494,15 @@ def pipeline_factory():
494494
check_is_pipeline_stateless(pipeline_factory)
495495

496496

497-
@params("cpu", "gpu")
498-
@stateless_signed_off("experimental.debayer")
499-
def test_debayer_stateless(device):
497+
@cartesian_params(("cpu", "gpu"), (fn.experimental.debayer, fn.debayer))
498+
@stateless_signed_off("experimental.debayer", "debayer")
499+
def test_debayer_stateless(device, debayer_op):
500500
@pipeline_def(enable_checkpointing=True)
501501
def pipeline_factory():
502502
data = fn.external_source(source=RandomBatch((40, 40)), layout="HW", batch=True)
503503
if device == "gpu":
504504
data = data.gpu()
505-
return fn.experimental.debayer(data, blue_position=[0, 0])
505+
return debayer_op(data, blue_position=[0, 0])
506506

507507
check_is_pipeline_stateless(pipeline_factory)
508508

@@ -772,21 +772,28 @@ def wrapper(x, **kwargs):
772772

773773

774774
@attr("numba")
775-
@stateless_signed_off("experimental.numba_function")
776-
def test_numba_function_stateless():
775+
@params(True, False)
776+
@stateless_signed_off("experimental.numba_function", "numba_function")
777+
def test_numba_function_stateless(use_experimental):
777778
import nvidia.dali.plugin.numba as dali_numba
778779

779780
check_numba_compatibility_cpu()
780781

781782
def double_sample(out_sample, in_sample):
782783
out_sample[:] = 2 * in_sample[:]
783784

785+
numba_function_op = (
786+
dali_numba.fn.experimental.numba_function
787+
if use_experimental
788+
else dali_numba.fn.numba_function
789+
)
790+
784791
@pipeline_def(batch_size=2, device_id=0, num_threads=4, enable_checkpointing=True)
785792
def numba_pipe():
786793
forty_two = fn.external_source(
787794
source=lambda x: np.full((2,), 42, dtype=np.uint8), batch=False
788795
)
789-
out = dali_numba.fn.experimental.numba_function(
796+
out = numba_function_op(
790797
forty_two,
791798
run_fn=double_sample,
792799
out_types=[types.DALIDataType.UINT8],

0 commit comments

Comments
 (0)