Skip to content

Commit

Permalink
Improve test coverage (#472)
Browse files Browse the repository at this point in the history
* Add back more tests for 'simple_optimize_dag' which had inadvertently been dropped

* Remove unneeded None check

* Re-instate test_stragglers with shorter run time

* Remove redundant code
  • Loading branch information
tomwhite authored Jun 5, 2024
1 parent 4eb886f commit afadd83
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
3 changes: 0 additions & 3 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,6 @@ def map_blocks(
) -> "Array":
"""Apply a function to corresponding blocks from multiple input arrays."""

if drop_axis is None:
drop_axis = []

# Handle the case where an array is created by calling `map_blocks` with no input arrays
if len(args) == 0:
from cubed.array_api.creation_functions import empty_virtual_array
Expand Down
6 changes: 2 additions & 4 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,13 +665,11 @@ def make_blockwise_key_function(
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
numblocks: Dict[str, Tuple[int, ...]],
new_axes: Optional[Dict[int, int]] = None,
) -> Callable[[List[int]], Any]:
"""Make a function that is the equivalent of make_blockwise_graph."""

if numblocks is None:
raise ValueError("Missing required numblocks argument.")
new_axes = new_axes or {}
argpairs = list(toolz.partition(2, arrind_pairs))

Expand Down Expand Up @@ -723,7 +721,7 @@ def make_blockwise_key_function_flattened(
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
numblocks: Dict[str, Tuple[int, ...]],
new_axes: Optional[Dict[int, int]] = None,
) -> Callable[[List[int]], Any]:
# TODO: make this a part of make_blockwise_key_function?
Expand Down
3 changes: 1 addition & 2 deletions cubed/tests/runtime/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,10 @@ def test_failure(tmp_path, timing_map, n_tasks, retries, use_backups):
@pytest.mark.parametrize(
"timing_map, n_tasks, retries",
[
({0: [60]}, 10, 2),
({0: [10]}, 10, 2),
],
)
# fmt: on
@pytest.mark.skip(reason="This passes, but Python will not exit until the slow task is done.")
def test_stragglers(tmp_path, timing_map, n_tasks, retries):
outputs = asyncio.run(
run_test(
Expand Down
37 changes: 28 additions & 9 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def spec(tmp_path):
return cubed.Spec(tmp_path, allowed_mem=100000)


def test_fusion(spec):
@pytest.mark.parametrize(
"opt_fn", [None, simple_optimize_dag, multiple_inputs_optimize_dag]
)
def test_fusion(spec, opt_fn):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.negative(a)
c = xp.astype(b, np.float32)
Expand All @@ -43,12 +46,20 @@ def test_fusion(spec):
)
num_arrays = 2 # a, d
num_created_arrays = 1 # d (a is not created on disk)
assert d.plan.num_arrays(optimize_graph=True) == num_arrays
assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 4
assert d.plan.total_nbytes_written(optimize_graph=True) == d.nbytes
assert (
d.plan.num_arrays(optimize_graph=True, optimize_function=opt_fn) == num_arrays
)
assert (
d.plan.num_tasks(optimize_graph=True, optimize_function=opt_fn)
== num_created_arrays + 4
)
assert (
d.plan.total_nbytes_written(optimize_graph=True, optimize_function=opt_fn)
== d.nbytes
)

task_counter = TaskCounter()
result = d.compute(callbacks=[task_counter])
result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
assert task_counter.value == num_created_arrays + 4

assert_array_equal(
Expand All @@ -57,7 +68,10 @@ def test_fusion(spec):
)


def test_fusion_transpose(spec):
@pytest.mark.parametrize(
"opt_fn", [None, simple_optimize_dag, multiple_inputs_optimize_dag]
)
def test_fusion_transpose(spec, opt_fn):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.negative(a)
c = xp.astype(b, np.float32)
Expand All @@ -66,10 +80,13 @@ def test_fusion_transpose(spec):
num_created_arrays = 3 # b, c, d
assert d.plan.num_tasks(optimize_graph=False) == num_created_arrays + 12
num_created_arrays = 1 # d
assert d.plan.num_tasks(optimize_graph=True) == num_created_arrays + 4
assert (
d.plan.num_tasks(optimize_graph=True, optimize_function=opt_fn)
== num_created_arrays + 4
)

task_counter = TaskCounter()
result = d.compute(callbacks=[task_counter])
result = d.compute(optimize_function=opt_fn, callbacks=[task_counter])
assert task_counter.value == num_created_arrays + 4

assert_array_equal(
Expand All @@ -81,6 +98,7 @@ def test_fusion_transpose(spec):
def test_fusion_map_direct(spec):
# test that operations after a map_direct operation (indexing) can be fused
# with the map_direct operation
# this is only true for the (default) multiple_inputs_optimize_dag optimize function
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = a[1:, :]
c = xp.negative(b) # should be fused with b
Expand All @@ -102,6 +120,7 @@ def test_fusion_map_direct(spec):

def test_no_fusion(spec):
# b can't be fused with c because d also depends on b
# this is only true for the simple_optimize_dag optimize function
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
b = xp.positive(a)
c = xp.positive(b)
Expand All @@ -126,7 +145,7 @@ def test_no_fusion_multiple_edges(spec):
c = xp.asarray(b)
# b and c are the same array, so d has a single dependency
# with multiple edges
# this should not be fused under the current logic
# this should not be fused under the current logic in simple_optimize_dag
d = xp.equal(b, c)

opt_fn = simple_optimize_dag
Expand Down

0 comments on commit afadd83

Please sign in to comment.