1515"""Helpers for Pallas Mosaic GPU kernels."""
1616
1717from collections .abc import Callable , Hashable , Sequence
18+ import dataclasses
1819import functools
1920import math
2021from typing import TypeVar , overload
2728_T = TypeVar ("_T" )
2829
2930
30- @overload
31- def nd_loop (
32- grid : Sequence [int ],
33- * ,
34- collective_axes : Sequence [Hashable ] | Hashable ,
35- tiling : Sequence [int ] | None = None ,
36- init_carry : None = None
37- ) -> Callable [[Callable [[Sequence [jax .Array ]], None ]], None ]:
38- ...
31+ @dataclasses .dataclass (frozen = True , eq = False )
32+ class NDLoopInfo :
33+ """Container dataclass for loop iteration information.
34+
35+ Attributes:
36+ index: The grid indices corresponding to the current loop iteration.
37+ local_index: The local iteration index.
38+ num_local_steps: The total number of local iterations to run.
39+ """
40+ index : tuple [jax .Array , ...]
41+ local_index : jax .Array | int
42+ num_local_steps : jax .Array | int
3943
4044
4145@overload
@@ -44,29 +48,30 @@ def nd_loop(
4448 * ,
4549 collective_axes : Sequence [Hashable ] | Hashable ,
4650 tiling : Sequence [int ] | None = None ,
47- init_carry : _T
48- ) -> Callable [[Callable [[Sequence [ jax . Array ], _T ], _T ]], _T ]:
51+ init_carry : None = None
52+ ) -> Callable [[Callable [[NDLoopInfo ], None ]], None ]:
4953 ...
5054
5155
52- # TODO(justinfu): Fix the type signature to include both carry and wave_step.
5356@overload
5457def nd_loop (
5558 grid : Sequence [int ],
5659 * ,
5760 collective_axes : Sequence [Hashable ] | Hashable ,
5861 tiling : Sequence [int ] | None = None ,
59- include_wave_step : bool
60- ) -> Callable [[Callable [[Sequence [ jax . Array ], jax . Array ], None ]], None ]:
62+ init_carry : _T
63+ ) -> Callable [[Callable [[NDLoopInfo , _T ], _T ]], _T ]:
6164 ...
6265
6366
64- def nd_loop (grid , * , collective_axes ,
65- tiling = None ,
66- init_carry = None ,
67- include_wave_step = False ):
67+ def nd_loop (grid , * , collective_axes , tiling = None , init_carry = None ):
6868 """A loop over a multi-dimensional grid partitioned along the given axes.
6969
70+ The body of the loop takes a single argument ``loop_info``, which is an
71+ :class:`NDLoopInfo` object containing index and iteration information. If a
72+ carry is specified, the body also expects a keyword argument ``carry``
73+ containing the loop carry.
74+
7075 For example, if ``collective_axes`` is ``"x"`` with :func:`lax.axis_size`
7176 equal to 4 and the grid is (2, 3), the implementation would produce the
7277 following iteration order
@@ -98,10 +103,6 @@ def nd_loop(grid, *, collective_axes,
98103 take and return the carry. If it's ``None`` then no carry argument is
99104 expected.
100105
101- If ``include_wave_step`` is True then the body will be called with an
102- additional ``wave_step`` keyword argument that specifies the current
103- iteration local to the thread.
104-
105106 See also:
106107 - :func:`jax.experimental.pallas.loop`: A loop over a single dimension.
107108 """
@@ -141,12 +142,15 @@ def wrapper(wave_step, carry):
141142 untiled_index .append (sub_idx + tile_idx * tile_dim )
142143 index = untiled_index
143144
144- if include_wave_step :
145- body = functools .partial (body , wave_step = wave_step )
145+ loop_info = NDLoopInfo (
146+ index = tuple (index ),
147+ local_index = wave_step ,
148+ num_local_steps = upper
149+ )
146150 if init_carry is None :
147- body (tuple ( index ) )
151+ body (loop_info )
148152 else :
149- return body (tuple ( index ) , carry = carry )
153+ return body (loop_info , carry = carry )
150154
151155 upper = lax .div (grid_size , axis_size ) + lax .convert_element_type (
152156 axis_index < grid_size % axis_size , axis_index .dtype
0 commit comments