@@ -27,7 +27,10 @@ def spec(tmp_path):
27
27
return cubed .Spec (tmp_path , allowed_mem = 100000 )
28
28
29
29
30
- def test_fusion (spec ):
30
+ @pytest .mark .parametrize (
31
+ "opt_fn" , [None , simple_optimize_dag , multiple_inputs_optimize_dag ]
32
+ )
33
+ def test_fusion (spec , opt_fn ):
31
34
a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
32
35
b = xp .negative (a )
33
36
c = xp .astype (b , np .float32 )
@@ -43,12 +46,20 @@ def test_fusion(spec):
43
46
)
44
47
num_arrays = 2 # a, d
45
48
num_created_arrays = 1 # d (a is not created on disk)
46
- assert d .plan .num_arrays (optimize_graph = True ) == num_arrays
47
- assert d .plan .num_tasks (optimize_graph = True ) == num_created_arrays + 4
48
- assert d .plan .total_nbytes_written (optimize_graph = True ) == d .nbytes
49
+ assert (
50
+ d .plan .num_arrays (optimize_graph = True , optimize_function = opt_fn ) == num_arrays
51
+ )
52
+ assert (
53
+ d .plan .num_tasks (optimize_graph = True , optimize_function = opt_fn )
54
+ == num_created_arrays + 4
55
+ )
56
+ assert (
57
+ d .plan .total_nbytes_written (optimize_graph = True , optimize_function = opt_fn )
58
+ == d .nbytes
59
+ )
49
60
50
61
task_counter = TaskCounter ()
51
- result = d .compute (callbacks = [task_counter ])
62
+ result = d .compute (optimize_function = opt_fn , callbacks = [task_counter ])
52
63
assert task_counter .value == num_created_arrays + 4
53
64
54
65
assert_array_equal (
@@ -57,7 +68,10 @@ def test_fusion(spec):
57
68
)
58
69
59
70
60
- def test_fusion_transpose (spec ):
71
+ @pytest .mark .parametrize (
72
+ "opt_fn" , [None , simple_optimize_dag , multiple_inputs_optimize_dag ]
73
+ )
74
+ def test_fusion_transpose (spec , opt_fn ):
61
75
a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
62
76
b = xp .negative (a )
63
77
c = xp .astype (b , np .float32 )
@@ -66,10 +80,13 @@ def test_fusion_transpose(spec):
66
80
num_created_arrays = 3 # b, c, d
67
81
assert d .plan .num_tasks (optimize_graph = False ) == num_created_arrays + 12
68
82
num_created_arrays = 1 # d
69
- assert d .plan .num_tasks (optimize_graph = True ) == num_created_arrays + 4
83
+ assert (
84
+ d .plan .num_tasks (optimize_graph = True , optimize_function = opt_fn )
85
+ == num_created_arrays + 4
86
+ )
70
87
71
88
task_counter = TaskCounter ()
72
- result = d .compute (callbacks = [task_counter ])
89
+ result = d .compute (optimize_function = opt_fn , callbacks = [task_counter ])
73
90
assert task_counter .value == num_created_arrays + 4
74
91
75
92
assert_array_equal (
@@ -81,6 +98,7 @@ def test_fusion_transpose(spec):
81
98
def test_fusion_map_direct (spec ):
82
99
# test that operations after a map_direct operation (indexing) can be fused
83
100
# with the map_direct operation
101
+ # this is only true for the (default) multiple_inputs_optimize_dag optimize function
84
102
a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 2 ), spec = spec )
85
103
b = a [1 :, :]
86
104
c = xp .negative (b ) # should be fused with b
@@ -102,6 +120,7 @@ def test_fusion_map_direct(spec):
102
120
103
121
def test_no_fusion (spec ):
104
122
# b can't be fused with c because d also depends on b
123
+ # this is only true for the simple_optimize_dag optimize function
105
124
a = xp .ones ((2 , 2 ), chunks = (2 , 2 ), spec = spec )
106
125
b = xp .positive (a )
107
126
c = xp .positive (b )
@@ -126,7 +145,7 @@ def test_no_fusion_multiple_edges(spec):
126
145
c = xp .asarray (b )
127
146
# b and c are the same array, so d has a single dependency
128
147
# with multiple edges
129
- # this should not be fused under the current logic
148
+ # this should not be fused under the current logic in simple_optimize_dag
130
149
d = xp .equal (b , c )
131
150
132
151
opt_fn = simple_optimize_dag
0 commit comments