
Commit b6996cb

justinjfu authored and Google-ML-Automation committed
[Pallas][Mosaic GPU] Use NDLoopInfo as the argument to the loop body of plgpu.nd_loop
PiperOrigin-RevId: 811413884
1 parent 85edead commit b6996cb

8 files changed (+59, -52 lines)

jax/_src/pallas/mosaic_gpu/helpers.py

Lines changed: 30 additions & 26 deletions
@@ -15,6 +15,7 @@
 """Helpers for Pallas Mosaic GPU kernels."""
 
 from collections.abc import Callable, Hashable, Sequence
+import dataclasses
 import functools
 import math
 from typing import TypeVar, overload
@@ -27,15 +28,18 @@
 _T = TypeVar("_T")
 
 
-@overload
-def nd_loop(
-    grid: Sequence[int],
-    *,
-    collective_axes: Sequence[Hashable] | Hashable,
-    tiling: Sequence[int] | None = None,
-    init_carry: None = None
-) -> Callable[[Callable[[Sequence[jax.Array]], None]], None]:
-  ...
+@dataclasses.dataclass(frozen=True, eq=False)
+class NDLoopInfo:
+  """Container dataclass for loop iteration information.
+
+  Attributes:
+    index: The grid indices corresponding to the current loop iteration.
+    local_index: The local iteration index.
+    num_local_steps: The total number of local iterations to run.
+  """
+  index: tuple[jax.Array, ...]
+  local_index: jax.Array | int
+  num_local_steps: jax.Array | int
 
 
 @overload
@@ -44,29 +48,30 @@ def nd_loop(
     *,
     collective_axes: Sequence[Hashable] | Hashable,
     tiling: Sequence[int] | None = None,
-    init_carry: _T
-) -> Callable[[Callable[[Sequence[jax.Array], _T], _T]], _T]:
+    init_carry: None = None
+) -> Callable[[Callable[[NDLoopInfo], None]], None]:
   ...
 
 
-# TODO(justinfu): Fix the type signature to include both carry and wave_step.
 @overload
 def nd_loop(
     grid: Sequence[int],
    *,
    collective_axes: Sequence[Hashable] | Hashable,
    tiling: Sequence[int] | None = None,
-    include_wave_step: bool
-) -> Callable[[Callable[[Sequence[jax.Array], jax.Array], None]], None]:
+    init_carry: _T
+) -> Callable[[Callable[[NDLoopInfo, _T], _T]], _T]:
   ...
 
 
-def nd_loop(grid, *, collective_axes,
-            tiling=None,
-            init_carry=None,
-            include_wave_step=False):
+def nd_loop(grid, *, collective_axes, tiling=None, init_carry=None):
   """A loop over a multi-dimensional grid partitioned along the given axes.
 
+  The body of the loop takes a single argument ``loop_info``, which is an
+  NDLoopInfo object containing index and iteration information. However, if a
+  carry is specified, the body will expect a second keyword argument ``carry``
+  containing the loop carry.
+
   For example, if ``collective_axes`` is ``"x"`` with :func:`lax.axis_size`
   equal to 4 and the grid is (2, 3), the implementation would produce the
   following iteration order
@@ -98,10 +103,6 @@ def nd_loop(grid, *, collective_axes,
   take and return the carry. If it's ``None`` then no carry argument is
   expected.
 
-  If ``include_wave_step`` is True then the body will be called with an
-  additional ``wave_step`` keyword argument that specifies the current
-  iteration local to the thread.
-
   See also:
    - :func:`jax.experimental.pallas.loop`: A loop over a single dimension.
  """
@@ -141,12 +142,15 @@ def wrapper(wave_step, carry):
        untiled_index.append(sub_idx + tile_idx * tile_dim)
      index = untiled_index
 
-    if include_wave_step:
-      body = functools.partial(body, wave_step=wave_step)
+    loop_info = NDLoopInfo(
+        index=tuple(index),
+        local_index=wave_step,
+        num_local_steps=upper
+    )
    if init_carry is None:
-      body(tuple(index))
+      body(loop_info)
    else:
-      return body(tuple(index), carry=carry)
+      return body(loop_info, carry=carry)
 
  upper = lax.div(grid_size, axis_size) + lax.convert_element_type(
      axis_index < grid_size % axis_size, axis_index.dtype
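
For reference, a minimal sketch of the new calling convention introduced above (not part of the commit itself; the kernel context, the "sm" axis name, and the trivial bodies are assumptions for illustration). Without init_carry the body receives a single NDLoopInfo; with init_carry it also receives a carry keyword argument and the decorator evaluates to the final carry.

# Hypothetical sketch, assumed to run inside a Pallas Mosaic GPU kernel
# that has an "sm" collective grid axis.
from jax.experimental.pallas import mosaic_gpu as plgpu

# Without a carry: the body takes only `loop_info`.
@plgpu.nd_loop((4, 33), collective_axes="sm")
def _(loop_info: plgpu.NDLoopInfo):
  i, j = loop_info.index              # grid indices of this iteration
  step = loop_info.local_index        # iteration index local to this "sm"
  total = loop_info.num_local_steps   # total local iterations for this "sm"
  del i, j, step, total               # placeholder body

# With a carry: the body also takes `carry` (passed by keyword) and returns
# the updated carry; the call evaluates the loop and yields the final carry.
num_visited = plgpu.nd_loop((4, 33), collective_axes="sm", init_carry=0)(
    lambda loop_info, carry: carry + 1
)

This mirrors the updated loop bodies in the op kernels and tests further down in this commit.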

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
 from jax._src.pallas.mosaic_gpu.helpers import find_swizzle as find_swizzle
 from jax._src.pallas.mosaic_gpu.helpers import format_tcgen05_sparse_metadata as format_tcgen05_sparse_metadata
 from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop
+from jax._src.pallas.mosaic_gpu.helpers import NDLoopInfo as NDLoopInfo
 from jax._src.pallas.mosaic_gpu.helpers import planar_snake as planar_snake
 from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
 from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
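
With this re-export, kernels can refer to the class through the public module rather than the private helpers path, which is how the op kernels below annotate their loop bodies. A one-line sketch (assumed usage):

from jax.experimental.pallas import mosaic_gpu as plgpu

def mn_loop(loop_info: plgpu.NDLoopInfo) -> None:  # annotation resolves via the public module
  ...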

jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py

Lines changed: 7 additions & 7 deletions
@@ -106,10 +106,10 @@ def kernel(a_gmem, b_gmem, out_gmem,
   is_lead_block = cluster_idx == 0
 
   @plgpu.nd_loop((m_iters * n_iters,),
-                 collective_axes="sm",
-                 include_wave_step=True)
-  def mn_loop(idx, wave_step):  # pylint: disable=unused-variable
-    (lin_idx,) = idx
+                 collective_axes="sm")
+  def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
+    (lin_idx,) = loop_info.index
+    local_index = loop_info.local_index
     m_index, n_index = plgpu.planar_snake(
         lin_idx,
         (m_iters, n_iters),
@@ -121,7 +121,7 @@ def mn_loop(idx, wave_step):  # pylint: disable=unused-variable
     block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m)
     slice_m = pl.ds(m_index * tile_m, tile_m)
     slice_n = pl.ds(n_index * tile_n, tile_n)
-    acc_slot = lax.rem(wave_step, jnp.int32(2))
+    acc_slot = lax.rem(local_index, jnp.int32(2))
 
     @pl.when(wg_idx == COMPUTE_WG)
     def _():
@@ -134,7 +134,7 @@ def _loop_body(ki, _):
         slice_k = pl.ds(ki * tile_k, tile_k)
         slot = lax.rem(ki, max_concurrent_steps)
         @pl.when(jnp.logical_or(ki >= max_concurrent_steps,
-                                wave_step > 0))
+                                local_index > 0))
         def _():
           plgpu.barrier_wait(consumed_barrier.at[slot])
         plgpu.copy_gmem_to_smem(
@@ -153,7 +153,7 @@ def _():
         )
       lax.fori_loop(0, k_iters, _loop_body, None)
 
-      @pl.when(jnp.logical_and(warp_id == MMA_WARP, wave_step > 1))
+      @pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1))
       def _wait_store():
         plgpu.barrier_wait(store_done_barrier.at[acc_slot])
       @pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block))

jax/experimental/pallas/ops/gpu/blackwell_ragged_dot_mgpu.py

Lines changed: 9 additions & 9 deletions
@@ -51,7 +51,7 @@ def do_matmul(a_gmem,
               grid_indices: Sequence[jax.Array],
               wg_axis: str,
               collective_axes: tuple[str, ...],
-              wave_step: jax.Array,
+              local_index: jax.Array,
               config: TuningConfig,
               group_info: ragged_dot_mgpu.GroupInfo,
               a_smem, b_smem, acc_tmem, acc_smem,
@@ -91,7 +91,7 @@ def do_matmul(a_gmem,
   block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m)
   slice_m = pl.ds(m_index * tile_m, tile_m)
   slice_n = pl.ds(n_index * tile_n, tile_n)
-  acc_slot = lax.rem(wave_step, jnp.int32(2))
+  acc_slot = lax.rem(local_index, jnp.int32(2))
   regs_layout = plgpu.Layout.TCGEN05
 
   @pl.when(wg_idx == COMPUTE_WG)
@@ -106,7 +106,7 @@ def _loop_body(ki, _):
       slice_k = pl.ds(ki * tile_k, tile_k)
       slot = lax.rem(ki, max_concurrent_steps)
       @pl.when(jnp.logical_or(ki >= max_concurrent_steps,
-                              wave_step > 0))
+                              local_index > 0))
       def _():
         plgpu.barrier_wait(consumed_barrier.at[slot])
       plgpu.copy_gmem_to_smem(
@@ -125,7 +125,7 @@ def _():
       )
     lax.fori_loop(0, k_iters, _loop_body, None)
 
-    @pl.when(jnp.logical_and(warp_id == MMA_WARP, wave_step > 1))
+    @pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1))
    def _wait_store():
      plgpu.barrier_wait(store_done_barrier.at[acc_slot])
    @pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block))
@@ -297,10 +297,10 @@ def kernel(a_gmem, b_gmem, group_sizes_gmem, out_gmem):
   )
   def _scoped(**ref_kwargs):
     @plgpu.nd_loop(grid=(linear_grid,),
-                   collective_axes="sm",
-                   include_wave_step=True)
-    def mn_loop(idx, wave_step):  # pylint: disable=unused-variable
-      linear_idx, = idx
+                   collective_axes="sm")
+    def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
+      linear_idx, = loop_info.index
+      local_index = loop_info.local_index  # type: ignore
      m_index, n_index = plgpu.planar_snake(
          linear_idx,
          (m_iters + num_groups - 1, n_iters),
@@ -318,7 +318,7 @@ def mn_loop(idx, wave_step):  # pylint: disable=unused-variable
          grid_indices=(group_info.block, n_index, cluster_idx),
          wg_axis="wg",
          collective_axes=("x",) if collective else (),
-          wave_step=wave_step,
+          local_index=local_index,  # type: ignore
          config=config,
          group_info=group_info,
          **ref_kwargs

jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py

Lines changed: 3 additions & 3 deletions
@@ -143,8 +143,8 @@ def get_pipeline(pipeline_body, compute_context):
     )
     def _pipeline_scope(pipeline_allocs):
       @plgpu.nd_loop((m_iters * n_iters,), collective_axes="sm")
-      def _mn_loop(idxs):
-        (lin_idx,) = idxs
+      def _mn_loop(loop_info: plgpu.NDLoopInfo):
+        (lin_idx,) = loop_info.index
         m_idx, n_idx = plgpu.planar_snake(
             lin_idx,
             (m_iters, n_iters),
@@ -158,7 +158,7 @@ def _mn_loop(idxs):
          wg_n_slice = slice(None)
        else:
          wg_m_slice = slice(None)
-          wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)
+          wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)  # type: ignore
 
        def compute_context(eval_pipeline):
          @functools.partial(

jax/experimental/pallas/ops/gpu/hopper_matmul_mgpu.py

Lines changed: 3 additions & 3 deletions
@@ -138,8 +138,8 @@ def _pipeline_scope(pipeline_allocs):
     wg_idx = lax.axis_index("wg")
     cta_idx = lax.axis_index("cluster")
     @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
-    def _mn_loop(idxs):
-      (lin_idx,) = idxs
+    def _mn_loop(loop_info: plgpu.NDLoopInfo):
+      (lin_idx,) = loop_info.index
       m_cluster_idx, n_cluster_idx = plgpu.planar_snake(
           lin_idx,
           (m_iters, n_iters),
@@ -159,7 +159,7 @@ def _mn_loop(idxs):
        wg_n_slice = slice(None)
      else:
        wg_m_slice = slice(None)
-        wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)
+        wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)  # type: ignore
 
      def compute_context(eval_pipeline):
        @functools.partial(

jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py

Lines changed: 2 additions & 2 deletions
@@ -126,8 +126,8 @@ def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem):
   )
 
   @plgpu.nd_loop(grid, collective_axes="sm")
-  def mn_loop(idx):  # pylint: disable=unused-variable
-    block_ni, mi, remainder_ni = idx
+  def mn_loop(loop_info: plgpu.NDLoopInfo):  # pylint: disable=unused-variable
+    block_ni, mi, remainder_ni = loop_info.index
     ni = block_ni * pl.cdiv(n, block_n * grid_block_n) + remainder_ni
     group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi)
 
tests/pallas/mosaic_gpu_test.py

Lines changed: 4 additions & 2 deletions
@@ -1979,7 +1979,8 @@ def test_nd_loop_with_carry(self, sm_steps):
         grid_names=("sm",),
     )
     def kernel(o_ref, steps_ref):
-      def body(idx, carry):
+      def body(loop_info, carry):
+        idx = loop_info.index
         assert len(idx) == 3
         # We need to use `mode="clip"`, because the indices are not static.
         flat_idx = jnp.ravel_multi_index(idx, (sm_steps, 4, 33), mode="clip")
@@ -2022,7 +2023,8 @@ def test_nd_loop(self, sm_steps: int, tiling: int | None):
     )
     def kernel(o_ref):
       @plgpu.nd_loop((sm_steps, 4, 33), tiling=tiling, collective_axes="sm")
-      def _(idx):
+      def _(loop_info):
+        idx = loop_info.index
         assert len(idx) == 3
         # We need to use `mode="clip"`, because the indices are not static.
         grid = (sm_steps, 4, 33)
