Skip to content

Commit 9584ee3

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test
Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing indexing at all! PiperOrigin-RevId: 698442820
1 parent 2c9b917 commit 9584ee3

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

tests/pallas/mosaic_gpu_test.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,33 +1218,33 @@ def kernel_body(x_smem, o_smem):
12181218
np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16])
12191219

12201220
def test_emit_with_parallel_grid(self):
1221-
self.skipTest("Enable once we support multiple levels of indexing")
1222-
1223-
num_steps = 4
1221+
num_steps1 = 4
1222+
num_steps2 = 5
12241223

12251224
def kernel(x_gmem, o_gmem):
1226-
gmem_slice = pl.ds(pl.program_id(0) * 32, 32)
1225+
pid = pl.program_id(0)
12271226
plgpu.emit_pipeline(
12281227
kernel_body,
1229-
in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
1230-
out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))],
1231-
grid=(num_steps,),
1228+
in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
1229+
out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))],
1230+
grid=(num_steps2,),
12321231
max_concurrent_steps=2,
1233-
)(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice])
1232+
)(x_gmem, o_gmem)
12341233

12351234
def kernel_body(x_smem, o_smem):
12361235
o_smem[...] = x_smem[...] + 1.0
12371236

1238-
x = jnp.arange(4 * 32 * num_steps * 16)
1239-
x = x.reshape(-1, num_steps * 16).astype(jnp.float32)
1237+
x = jnp.arange(num_steps1 * 32 * num_steps2 * 16)
1238+
x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32)
12401239
kernel_fn = pl.pallas_call(
12411240
kernel,
12421241
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
12431242
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
12441243
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
1245-
grid=(4, 1),
1244+
grid=(num_steps1,),
12461245
)
1247-
np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
1246+
y = x + 1.0
1247+
np.testing.assert_array_equal(kernel_fn(x), y)
12481248

12491249
def test_emit_with_2d_grid(self):
12501250
num_steps1 = 4

0 commit comments

Comments
 (0)