Skip to content

Commit afadd83

Browse files
authored
Improve test coverage (#472)
* Add back more tests for 'simple_optimize_dag' which had inadvertently been dropped * Remove unneeded None check * Re-instate test_stragglers with shorter run time * Remove redundant code
1 parent 4eb886f commit afadd83

File tree

4 files changed

+31
-18
lines changed

4 files changed

+31
-18
lines changed

cubed/core/ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,6 @@ def map_blocks(
571571
) -> "Array":
572572
"""Apply a function to corresponding blocks from multiple input arrays."""
573573

574-
if drop_axis is None:
575-
drop_axis = []
576-
577574
# Handle the case where an array is created by calling `map_blocks` with no input arrays
578575
if len(args) == 0:
579576
from cubed.array_api.creation_functions import empty_virtual_array

cubed/primitive/blockwise.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,11 @@ def make_blockwise_key_function(
665665
output: str,
666666
out_indices: Sequence[Union[str, int]],
667667
*arrind_pairs: Any,
668-
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
668+
numblocks: Dict[str, Tuple[int, ...]],
669669
new_axes: Optional[Dict[int, int]] = None,
670670
) -> Callable[[List[int]], Any]:
671671
"""Make a function that is the equivalent of make_blockwise_graph."""
672672

673-
if numblocks is None:
674-
raise ValueError("Missing required numblocks argument.")
675673
new_axes = new_axes or {}
676674
argpairs = list(toolz.partition(2, arrind_pairs))
677675

@@ -723,7 +721,7 @@ def make_blockwise_key_function_flattened(
723721
output: str,
724722
out_indices: Sequence[Union[str, int]],
725723
*arrind_pairs: Any,
726-
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
724+
numblocks: Dict[str, Tuple[int, ...]],
727725
new_axes: Optional[Dict[int, int]] = None,
728726
) -> Callable[[List[int]], Any]:
729727
# TODO: make this a part of make_blockwise_key_function?

cubed/tests/runtime/test_local.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,10 @@ def test_failure(tmp_path, timing_map, n_tasks, retries, use_backups):
8383
@pytest.mark.parametrize(
8484
"timing_map, n_tasks, retries",
8585
[
86-
({0: [60]}, 10, 2),
86+
({0: [10]}, 10, 2),
8787
],
8888
)
8989
# fmt: on
90-
@pytest.mark.skip(reason="This passes, but Python will not exit until the slow task is done.")
9190
def test_stragglers(tmp_path, timing_map, n_tasks, retries):
9291
outputs = asyncio.run(
9392
run_test(

cubed/tests/test_optimization.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def spec(tmp_path):
2727
return cubed.Spec(tmp_path, allowed_mem=100000)
2828

2929

30-
def test_fusion(spec):
30+
@pytest.mark.parametrize(
31+
"opt_fn", [None, simple_optimize_dag, multiple_inputs_optimize_dag]
32+
)
33+
def test_fusion(spec, opt_fn):
3134
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
3235
b = xp.negative(a)
3336
c = xp.astype(b, np.float32)
@@ -43,12 +46,20 @@ def test_fusion(spec):
4346
)
4447
num_arrays = 2 # a, d
4548
num_created_arrays = 1 # d (a is not created on disk)
46-
assert d.plan.num_arrays(optimize_graph=True) == num_arrays
47-
assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 4
48-
assert d.plan.total_nbytes_written(optimize_graph=True) == d.nbytes
49+
assert (
50+
d.plan.num_arrays(optimize_graph=True, optimize_function=opt_fn) == num_arrays
51+
)
52+
assert (
53+
d.plan.num_tasks(optimize_graph=True, optimize_function=opt_fn)
54+
== num_created_arrays + 4
55+
)
56+
assert (
57+
d.plan.total_nbytes_written(optimize_graph=True, optimize_function=opt_fn)
58+
== d.nbytes
59+
)
4960

5061
task_counter = TaskCounter()
51-
result = d.compute(callbacks=[task_counter])
62+
result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
5263
assert task_counter.value == num_created_arrays + 4
5364

5465
assert_array_equal(
@@ -57,7 +68,10 @@ def test_fusion(spec):
5768
)
5869

5970

60-
def test_fusion_transpose(spec):
71+
@pytest.mark.parametrize(
72+
"opt_fn", [None, simple_optimize_dag, multiple_inputs_optimize_dag]
73+
)
74+
def test_fusion_transpose(spec, opt_fn):
6175
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
6276
b = xp.negative(a)
6377
c = xp.astype(b, np.float32)
@@ -66,10 +80,13 @@ def test_fusion_transpose(spec):
6680
num_created_arrays = 3 # b, c, d
6781
assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 12
6882
num_created_arrays = 1 # d
69-
assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 4
83+
assert (
84+
d.plan.num_tasks(optimize_graph=True, optimize_function=opt_fn)
85+
== num_created_arrays + 4
86+
)
7087

7188
task_counter = TaskCounter()
72-
result = d.compute(callbacks=[task_counter])
89+
result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
7390
assert task_counter.value == num_created_arrays + 4
7491

7592
assert_array_equal(
@@ -81,6 +98,7 @@ def test_fusion_transpose(spec):
8198
def test_fusion_map_direct(spec):
8299
# test that operations after a map_direct operation (indexing) can be fused
83100
# with the map_direct operation
101+
# this is only true for the (default) multiple_inputs_optimize_dag optimize function
84102
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
85103
b = a[1:, :]
86104
c = xp.negative(b) # should be fused with b
@@ -102,6 +120,7 @@ def test_fusion_map_direct(spec):
102120

103121
def test_no_fusion(spec):
104122
# b can't be fused with c because d also depends on b
123+
# this is only true for the simple_optimize_dag optimize function
105124
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
106125
b = xp.positive(a)
107126
c = xp.positive(b)
@@ -126,7 +145,7 @@ def test_no_fusion_multiple_edges(spec):
126145
c = xp.asarray(b)
127146
# b and c are the same array, so d has a single dependency
128147
# with multiple edges
129-
# this should not be fused under the current logic
148+
# this should not be fused under the current logic in simple_optimize_dag
130149
d = xp.equal(b, c)
131150

132151
opt_fn = simple_optimize_dag

0 commit comments

Comments
 (0)