Skip to content

Commit 9fd571d

Browse files
committed
Implement ragged gather TC Pallas kernel to shard computation across devices
1 parent ac548a6 commit 9fd571d

File tree

3 files changed

+431
-4
lines changed

3 files changed

+431
-4
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
"""TensorCore-based Pallas ragged gather kernel."""
2+
3+
import dataclasses
4+
import functools
5+
6+
import jax
7+
from jax import numpy as jnp
8+
from jax import tree_util
9+
from jax._src.pallas.mosaic import pipeline
10+
from jax.experimental import pallas as pl
11+
from jax.experimental.pallas import tpu as pltpu
12+
13+
_NUM_BUFFERS = 2
14+
15+
16+
@tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class GatherBufferedRef(pipeline.BufferedRef):
  """Custom BufferedRef managing async DMA for gathering operations.

  Overrides `copy_in` and `wait_in` to break standard contiguous block
  iteration. Instead, it dynamically orchestrates DMA transfers using
  `idx_aligned_ref` to fetch 8-element contiguous segments from HBM. The fetched
  data sits in a `(block_size, 8, hidden_dim)` VMEM scratch buffer.
  """

  # Tokens fetched per pipeline step. Marked static so it is treated as
  # metadata (not a traced leaf) when this dataclass is flattened as a pytree.
  block_size: int = dataclasses.field(metadata={"static": True}, default=0)

  @classmethod
  def create(
      cls,
      spec: pl.BlockSpec,
      source_array: jax.Array,
      block_size: int,
      buffer_count: int = _NUM_BUFFERS,
  ):
    """Builds a standard INPUT BufferedRef for `source_array`, then rewraps it.

    Args:
      spec: Block spec describing the VMEM window shape for one pipeline step.
      source_array: The HBM-resident array (or ref) the gather reads from.
      block_size: Number of tokens fetched per pipeline step.
      buffer_count: Number of multi-buffering slots (double-buffered by
        default).

    Returns:
      A `GatherBufferedRef` sharing all buffers of the standard ref.
    """
    # NOTE(review): `pipeline._ref_to_value_aval` is a private JAX helper;
    # this may break across JAX versions — confirm a public equivalent.
    standard_ref = pipeline.BufferedRef.create(
        spec=spec,
        dtype_or_type=pipeline._ref_to_value_aval(source_array),
        buffer_type=pipeline.BufferType.INPUT,
        buffer_count=buffer_count,
        grid_rank=1,
        source_memory_space=pltpu.HBM,
    )
    return cls.from_ref(
        standard_ref,
        block_size=block_size,
    )

  @classmethod
  def from_ref(
      cls,
      ref: pipeline.BufferedRef,
      *,
      block_size: int = 0,
  ):
    """Copies every field of a plain `BufferedRef` into this subclass."""
    return cls(
        block_size=block_size,
        **{
            f.name: getattr(ref, f.name)
            for f in dataclasses.fields(pipeline.BufferedRef)
        },
    )

  def copy_in(self, src_ref, grid_indices):
    """Starts one aligned 8-row HBM->VMEM DMA per token of this grid step.

    `src_ref` is the tuple handed to the pipeline:
    `(x_hbm_ref, idx_aligned_ref, aligned_start_ref)`.
    """
    x_hbm_ref, idx_aligned_ref, aligned_start_ref = src_ref
    slot = self.current_copy_in_slot
    block_idx = grid_indices[0]

    # Absolute index (into the padded index arrays) of the first token
    # handled by this grid step.
    global_block_start = aligned_start_ref[0] + block_idx * self.block_size

    # Unrolled at trace time: block_size independent DMAs per step.
    for i in range(self.block_size):
      global_token_idx = global_block_start + i
      # idx_aligned_ref holds `indices & ~7`, so the value is 8-aligned by
      # construction; `pl.multiple_of` passes that fact to the compiler.
      idx_aligned = pl.multiple_of(idx_aligned_ref[global_token_idx], 8)

      assert self.sem_recvs is not None
      # All `block_size` copies signal the same per-slot semaphore; `wait_in`
      # drains the combined byte count with a single wait.
      pltpu.make_async_copy(
          x_hbm_ref.at[pl.ds(idx_aligned, 8), :],
          self.window_ref.at[slot, i, :, :],
          self.sem_recvs.at[slot],
      ).start()

  def wait_in(self, src_ref, grid_indices):
    """Blocks until every DMA issued by `copy_in` for this slot has landed."""
    wait_slot = self.current_wait_in_slot

    assert self.sem_recvs is not None
    # Dummy descriptor (source == destination) that is never started: its
    # byte size equals the total moved by the `block_size` copies started in
    # `copy_in`, so `.wait()` decrements the per-slot semaphore by exactly
    # the amount those copies incremented it.
    pltpu.make_async_copy(
        self.window_ref.at[wait_slot, : self.block_size, :, :],
        self.window_ref.at[wait_slot, : self.block_size, :, :],
        self.sem_recvs.at[wait_slot],
    ).wait()
92+
93+
94+
def inner_kernel(
    block_size: int,
    aligned_start_ref,
    end_idx_ref,
    local_start_ref,
    idx_mod_8_ref,
    x_vmem,
    o_vmem,
):
  """Inner kernel to perform the actual gather operation for a single block.

  Args:
    block_size: The number of elements to process per block.
    aligned_start_ref: The absolute start index, rounded down to the nearest
      multiple of `block_size`.
    end_idx_ref: The absolute end index. Used to mask out-of-bounds
      calculations in the final execution block.
    local_start_ref: Offset in `[0, block_size)` between the true `start_idx`
      and `aligned_start_ref`. Used to mask out invalid elements inside the
      very first execution block.
    idx_mod_8_ref: A tensor of shape (total_indices + block_size,) that
      contains `indices & 7`, providing the local sub-row offsets within the
      8-element chunks physically loaded from HBM.
    x_vmem: 8-element chunks of prefetched data in VMEM with shape
      (block_size, 8, hidden_dim).
    o_vmem: Output tensor of shape (block_size, hidden_dim).
  """
  block_idx = pl.program_id(0)

  def _inner_kernel(is_first_block: bool, is_last_block: bool):
    # Absolute index of this block's first token. The start is 8-aligned by
    # construction (rounded to a multiple of block_size, itself a multiple
    # of 8); `pl.multiple_of` exposes that to the compiler.
    global_block_start = (
        pl.multiple_of(aligned_start_ref[0], 8) + block_idx * block_size
    )

    local_start = local_start_ref[0]
    # Position of the absolute end index relative to this block's start.
    local_end = end_idx_ref[0] - global_block_start

    # Unrolled at trace time: one gather per token in the block.
    for i in range(block_size):
      global_token_idx = global_block_start + i
      mod_8 = idx_mod_8_ref[global_token_idx]
      # Broadcast the single sub-row offset to an (8, 128) index tile so
      # take_along_axis selects the same row in every lane.
      row_indices = jnp.broadcast_to(mod_8, (8, 128)).astype(jnp.int32)

      # Mask tokens before `start_idx` (first block) and at/after `end_idx`
      # (last block); interior blocks need no mask.
      if is_first_block and is_last_block:
        is_valid_mask = (i >= local_start) & (i < local_end)
      elif is_first_block:
        is_valid_mask = i >= local_start
      elif is_last_block:
        is_valid_mask = i < local_end
      else:
        is_valid_mask = None

      # Iterate over 128-width chunks of hidden dims to reuse row_indices.
      # NOTE(review): assumes hidden_dim is a multiple of 128 — confirm at
      # call sites.
      hidden_dim = x_vmem.shape[-1]
      for c in range(0, hidden_dim, 128):
        cols = pl.ds(c, 128)
        extracted = jnp.take_along_axis(
            x_vmem[i, :, cols].astype(jnp.float32),
            row_indices,
            axis=0,
        )
        # All 8 rows of `extracted` are identical (broadcast indices), so
        # row 0 is the gathered token slice.
        if is_valid_mask is not None:
          result = jnp.where(
              is_valid_mask, extracted[0], jnp.zeros_like(extracted[0])
          )
        else:
          result = extracted[0]
        o_vmem[i, cols] = result.astype(o_vmem.dtype)

  # Four specializations so interior blocks compile without masking work.
  @jax.named_scope("gather_first_last")
  def gather_first_last():
    _inner_kernel(is_first_block=True, is_last_block=True)

  @jax.named_scope("gather_first")
  def gather_first():
    _inner_kernel(is_first_block=True, is_last_block=False)

  @jax.named_scope("gather")
  def gather():
    _inner_kernel(is_first_block=False, is_last_block=False)

  @jax.named_scope("gather_last")
  def gather_last():
    _inner_kernel(is_first_block=False, is_last_block=True)

  is_first_block = block_idx == 0
  is_last_block = block_idx == (pl.num_programs(0) - 1)

  # Dynamic dispatch on block position (block count is not static here).
  jax.lax.cond(
      is_first_block,
      lambda: jax.lax.cond(
          is_last_block,
          gather_first_last,
          gather_first,
      ),
      lambda: jax.lax.cond(
          is_last_block,
          gather_last,
          gather,
      ),
  )
194+
195+
196+
def tensorcore_gather(
    x: jax.Array,
    indices: jax.Array,
    start_idx: int | jax.Array | None = None,
    end_idx: int | jax.Array | None = None,
    block_size: int = 32,
) -> jax.Array:
  """Gathers a range of tokens from x using TensorCore.

  Args:
    x: Source array of shape (num_rows, hidden_dim); pinned to HBM.
    indices: 1-D int array of row indices into `x`, one per output row.
    start_idx: First position in `indices` to gather (defaults to 0).
    end_idx: One-past-last position in `indices` to gather (defaults to
      `indices.shape[0]`).
    block_size: Tokens processed per pipeline step; must be a multiple of 8,
      and must divide `indices.shape[0]`.

  Returns:
    Array of shape (total_indices, hidden_dim) and dtype of `x`. Only rows in
    the block-aligned range around `[start_idx, end_idx)` are written; rows
    inside that range but outside `[start_idx, end_idx)` are zeroed.
  """
  # NOTE(review): inconsistent validation style — `assert` here (stripped
  # under -O) vs ValueError below; consider raising for both.
  assert (
      block_size % 8 == 0
  ), f"block_size must be divisible by 8, got {block_size}"
  total_indices = indices.shape[0]
  hidden_dim = x.shape[1]
  dtype = x.dtype

  if start_idx is None:
    start_idx = 0
  if end_idx is None:
    end_idx = total_indices

  if total_indices % block_size != 0:
    raise ValueError(
        f"total_indices ({total_indices}) must be a multiple of block_size"
        f" ({block_size})."
    )

  # Round the requested [start_idx, end_idx) range out to whole blocks.
  aligned_start = (start_idx // block_size) * block_size
  aligned_end = pl.cdiv(end_idx, block_size) * block_size
  num_blocks = pl.cdiv(aligned_end - aligned_start, block_size)
  # Invalid-token offset inside the first block, in [0, block_size).
  local_start = start_idx - aligned_start

  # Split each index into an 8-aligned base (for DMA) and a sub-row offset
  # (for the in-VMEM select). Pad by one block so the last grid step can
  # safely read past the logical end.
  idx_aligned_padded = jnp.pad(indices & ~7, (0, block_size))
  idx_mod_8_padded = jnp.pad(indices & 7, (0, block_size))

  @jax.named_scope("tensorcore_gather_kernel")
  def gather_kernel(
      num_blocks_ref,
      aligned_start_ref,
      end_idx_ref,
      local_start_ref,
      idx_aligned_ref,
      idx_mod_8_ref,
      x_hbm_ref,
      o_hbm_ref,
  ):
    """Executes the Gather pipeline over a perfectly tiled local execution grid.

    Args:
      num_blocks_ref: Scalar value of the number of blocks to process.
      aligned_start_ref: The absolute start index, rounded down to the nearest
        multiple of `block_size`.
      end_idx_ref: The absolute end index. Used to mask out-of-bounds
        calculations in the final execution block.
      local_start_ref: Offset in `[0, block_size)` between the true
        `start_idx` and `aligned_start_ref`. Used to mask out invalid elements
        inside the very first execution block.
      idx_aligned_ref: A tensor of shape (total_indices + block_size,) that
        contains `indices & ~7`. Used to dispatch aligned HBM fetches for each
        token.
      idx_mod_8_ref: A tensor of shape (total_indices + block_size,) that
        contains `indices & 7`, providing the local sub-row offsets within the
        8-element chunks physically loaded from HBM.
      x_hbm_ref: The input tensor referenced in HBM logic.
      o_hbm_ref: The output tensor referenced in HBM logic.
    """
    # Bind the scalar-prefetch refs; the pipeline supplies x_vmem / o_vmem.
    inner_kernel_partial = functools.partial(
        inner_kernel,
        block_size,
        aligned_start_ref,
        end_idx_ref,
        local_start_ref,
        idx_mod_8_ref,
    )

    # VMEM window holding block_size gathered 8-row chunks per step.
    _in_specs = [
        pl.BlockSpec(
            index_map=lambda *idx: idx,
            memory_space=pltpu.VMEM,
            block_shape=(block_size, 8, hidden_dim),
        ),
    ]

    def o_index_map(i):
      # Offset output blocks so block 0 of the local grid lands at the
      # aligned start of the global output.
      start_block_idx = aligned_start_ref[0] // block_size
      return (start_block_idx + i, 0)

    _out_specs = [
        pl.BlockSpec(
            index_map=o_index_map,
            memory_space=pltpu.VMEM,
            block_shape=(block_size, hidden_dim),
            pipeline_mode=pl.Buffered(buffer_count=_NUM_BUFFERS),
        ),
    ]

    # Grid length is dynamic: only the blocks covering [start, end) run.
    pipeline_func = pipeline.emit_pipeline(
        inner_kernel_partial,
        grid=(num_blocks_ref[0],),
        in_specs=_in_specs,
        out_specs=_out_specs,
    )

    # Custom input allocation: copy_in/wait_in issue the gather DMAs.
    x_alloc = GatherBufferedRef.create(
        spec=_in_specs[0],
        source_array=x_hbm_ref,
        block_size=block_size,
    )

    # Standard double-buffered output allocation.
    # NOTE(review): uses private `pipeline._ref_to_value_aval` — see
    # GatherBufferedRef.create.
    o_alloc = pipeline.BufferedRef.create(
        spec=_out_specs[0],
        dtype_or_type=pipeline._ref_to_value_aval(o_hbm_ref),
        buffer_type=pipeline.BufferType.OUTPUT,
        buffer_count=_NUM_BUFFERS,
        grid_rank=1,
        source_memory_space=pltpu.HBM,
    )

    def _run(allocs):
      # The input "ref" is the tuple unpacked by GatherBufferedRef.copy_in.
      pipeline_func(
          (x_hbm_ref, idx_aligned_ref, aligned_start_ref),
          o_hbm_ref,
          allocations=allocs,
      )

    pl.run_scoped(_run, (x_alloc, o_alloc))

  # Keep x in HBM; the kernel DMAs 8-row slices into VMEM itself.
  x = pltpu.with_memory_space_constraint(x, pltpu.HBM)
  # Six scalar-prefetch operands: num_blocks, aligned_start, end_idx,
  # local_start, idx_aligned, idx_mod_8.
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=6,
      in_specs=[
          pl.BlockSpec(
              memory_space=pltpu.HBM,
              pipeline_mode=pl.Buffered(buffer_count=_NUM_BUFFERS),
          ),
      ],
      out_specs=pl.BlockSpec(
          memory_space=pltpu.HBM,
          pipeline_mode=pl.Buffered(buffer_count=_NUM_BUFFERS),
      ),
      scratch_shapes=[],
  )
  # Wrap python/traced scalars as (1,)-shaped int32 arrays for prefetch.
  to_arr = lambda x: jnp.array([x], dtype=jnp.int32)

  res = pl.pallas_call(
      gather_kernel,
      out_shape=jax.ShapeDtypeStruct((total_indices, hidden_dim), dtype),
      grid_spec=grid_spec,
      name=f"tc_gather_hidden{hidden_dim}_numidx{total_indices}_block{block_size}",
      metadata={
          "block_size": str(block_size),
          "hidden_dim": str(hidden_dim),
          "total_indices": str(total_indices),
          "dtype": str(dtype),
          "num_buffers": str(_NUM_BUFFERS),
      },
  )(
      to_arr(num_blocks),
      to_arr(aligned_start),
      to_arr(end_idx),
      to_arr(local_start),
      idx_aligned_padded,
      idx_mod_8_padded,
      x,
  )

  return res

0 commit comments

Comments
 (0)