@@ -1218,33 +1218,33 @@ def kernel_body(x_smem, o_smem):
12181218 np .testing .assert_array_equal (kernel_fn (x )[:, :16 ], y [:, :16 ])
12191219
12201220 def test_emit_with_parallel_grid (self ):
1221- self .skipTest ("Enable once we support multiple levels of indexing" )
1222-
1223- num_steps = 4
1221+ num_steps1 = 4
1222+ num_steps2 = 5
12241223
12251224 def kernel (x_gmem , o_gmem ):
1226- gmem_slice = pl .ds ( pl . program_id (0 ) * 32 , 32 )
1225+ pid = pl .program_id (0 )
12271226 plgpu .emit_pipeline (
12281227 kernel_body ,
1229- in_specs = [pl .BlockSpec ((32 , 16 ), lambda i : (0 , i ))],
1230- out_specs = [pl .BlockSpec ((32 , 16 ), lambda i : (0 , i ))],
1231- grid = (num_steps ,),
1228+ in_specs = [pl .BlockSpec ((32 , 16 ), lambda i : (pid , i ))],
1229+ out_specs = [pl .BlockSpec ((32 , 16 ), lambda i : (pid , i ))],
1230+ grid = (num_steps2 ,),
12321231 max_concurrent_steps = 2 ,
1233- )(x_gmem . at [ gmem_slice ] , o_gmem . at [ gmem_slice ] )
1232+ )(x_gmem , o_gmem )
12341233
12351234 def kernel_body (x_smem , o_smem ):
12361235 o_smem [...] = x_smem [...] + 1.0
12371236
1238- x = jnp .arange (4 * 32 * num_steps * 16 )
1239- x = x .reshape (- 1 , num_steps * 16 ).astype (jnp .float32 )
1237+ x = jnp .arange (num_steps1 * 32 * num_steps2 * 16 )
1238+ x = x .reshape (- 1 , num_steps2 * 16 ).astype (jnp .float32 )
12401239 kernel_fn = pl .pallas_call (
12411240 kernel ,
12421241 in_specs = [pl .BlockSpec (memory_space = plgpu .GMEM )],
12431242 out_specs = pl .BlockSpec (memory_space = plgpu .GMEM ),
12441243 out_shape = jax .ShapeDtypeStruct (x .shape , x .dtype ),
1245- grid = (4 , 1 ),
1244+ grid = (num_steps1 , ),
12461245 )
1247- np .testing .assert_array_equal (kernel_fn (x ), x + 1.0 )
1246+ y = x + 1.0
1247+ np .testing .assert_array_equal (kernel_fn (x ), y )
12481248
12491249 def test_emit_with_2d_grid (self ):
12501250 num_steps1 = 4
0 commit comments