|
| 1 | +"""TensorCore-based Pallas ragged gather kernel.""" |
| 2 | + |
| 3 | +import dataclasses |
| 4 | +import functools |
| 5 | + |
| 6 | +import jax |
| 7 | +from jax import numpy as jnp |
| 8 | +from jax import tree_util |
| 9 | +from jax._src.pallas.mosaic import pipeline |
| 10 | +from jax.experimental import pallas as pl |
| 11 | +from jax.experimental.pallas import tpu as pltpu |
| 12 | + |
| 13 | +_NUM_BUFFERS = 2 |
| 14 | + |
| 15 | + |
@tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class GatherBufferedRef(pipeline.BufferedRef):
  """Custom BufferedRef managing async DMA for gathering operations.

  Overrides `copy_in` and `wait_in` to break standard contiguous block
  iteration. Instead, it dynamically orchestrates DMA transfers using
  `idx_aligned_ref` to fetch 8-element contiguous segments from HBM. The fetched
  data sits in a `(block_size, 8, hidden_dim)` VMEM scratch buffer.
  """

  # Tokens handled per pipeline step. Marked static so it stays a trace-time
  # Python int and can drive the unrolled DMA loop in `copy_in`.
  block_size: int = dataclasses.field(metadata={"static": True}, default=0)

  @classmethod
  def create(
      cls,
      spec: pl.BlockSpec,
      source_array: jax.Array,
      block_size: int,
      buffer_count: int = _NUM_BUFFERS,
  ):
    """Builds a standard INPUT BufferedRef, then rewraps it as this class."""
    # NOTE(review): depends on the private helper
    # `pipeline._ref_to_value_aval`; may break across jax versions.
    standard_ref = pipeline.BufferedRef.create(
        spec=spec,
        dtype_or_type=pipeline._ref_to_value_aval(source_array),
        buffer_type=pipeline.BufferType.INPUT,
        buffer_count=buffer_count,
        grid_rank=1,
        source_memory_space=pltpu.HBM,
    )
    return cls.from_ref(
        standard_ref,
        block_size=block_size,
    )

  @classmethod
  def from_ref(
      cls,
      ref: pipeline.BufferedRef,
      *,
      block_size: int = 0,
  ):
    """Clones every base-class dataclass field of `ref` into a new instance."""
    return cls(
        block_size=block_size,
        **{
            f.name: getattr(ref, f.name)
            for f in dataclasses.fields(pipeline.BufferedRef)
        },
    )

  def copy_in(self, src_ref, grid_indices):
    """Starts one aligned 8-row DMA from HBM for each token in the block.

    `src_ref` is the tuple `(x_hbm_ref, idx_aligned_ref, aligned_start_ref)`
    handed to the pipeline by the caller. All `block_size` copies signal the
    same per-slot semaphore; `wait_in` performs a single wait sized for
    their combined extent.
    """
    x_hbm_ref, idx_aligned_ref, aligned_start_ref = src_ref
    slot = self.current_copy_in_slot
    block_idx = grid_indices[0]

    # Absolute position (within `indices`) of this block's first token.
    global_block_start = aligned_start_ref[0] + block_idx * self.block_size

    for i in range(self.block_size):
      global_token_idx = global_block_start + i
      # `idx_aligned_ref` holds `indices & ~7`, so asserting a multiple of 8
      # lets the compiler emit an aligned fetch.
      idx_aligned = pl.multiple_of(idx_aligned_ref[global_token_idx], 8)

      assert self.sem_recvs is not None
      pltpu.make_async_copy(
          x_hbm_ref.at[pl.ds(idx_aligned, 8), :],
          self.window_ref.at[slot, i, :, :],
          self.sem_recvs.at[slot],
      ).start()

  def wait_in(self, src_ref, grid_indices):
    """Blocks until every DMA issued by `copy_in` for this slot has landed.

    The copy descriptor below is never started; it only sizes the semaphore
    wait. Its `(block_size, 8, hidden)` extent equals the total of the
    `block_size` individual `(8, hidden)` copies started in `copy_in`, so
    one wait covers them all.
    """
    wait_slot = self.current_wait_in_slot

    assert self.sem_recvs is not None
    pltpu.make_async_copy(
        self.window_ref.at[wait_slot, : self.block_size, :, :],
        self.window_ref.at[wait_slot, : self.block_size, :, :],
        self.sem_recvs.at[wait_slot],
    ).wait()
| 92 | + |
| 93 | + |
def inner_kernel(
    block_size: int,
    aligned_start_ref,
    end_idx_ref,
    local_start_ref,
    idx_mod_8_ref,
    x_vmem,
    o_vmem,
):
  """Gathers one block's worth of rows out of the prefetched VMEM chunks.

  For each token, `x_vmem` holds an aligned 8-row slab of the source array;
  this kernel picks the single wanted row (via the precomputed `index & 7`
  offsets) and writes it to `o_vmem`, zeroing tokens that fall outside the
  valid range in the first and last grid blocks.

  Args:
    block_size: Number of tokens handled per block (trace-time Python int).
    aligned_start_ref: Scalar ref: absolute start index rounded down to a
      multiple of `block_size`.
    end_idx_ref: Scalar ref: absolute end index; masks the tail of the final
      block.
    local_start_ref: Scalar ref: offset of the true start inside the first
      block (in [0, block_size)); masks the head of the first block.
    idx_mod_8_ref: Ref of shape (total_indices + block_size,) holding
      `indices & 7` — the row offset inside each prefetched 8-row slab.
    x_vmem: Prefetched slabs in VMEM, shape (block_size, 8, hidden_dim).
    o_vmem: Output block, shape (block_size, hidden_dim).
  """
  block_idx = pl.program_id(0)

  def _gather_block(first: bool, last: bool):
    base = pl.multiple_of(aligned_start_ref[0], 8) + block_idx * block_size
    head = local_start_ref[0]
    tail = end_idx_ref[0] - base

    for token in range(block_size):
      sub_row = idx_mod_8_ref[base + token]
      rows = jnp.broadcast_to(sub_row, (8, 128)).astype(jnp.int32)

      if first and last:
        valid = (token >= head) & (token < tail)
      elif first:
        valid = token >= head
      elif last:
        valid = token < tail
      else:
        valid = None

      # Walk the hidden dimension in 128-wide lanes so `rows` is reused.
      for col in range(0, x_vmem.shape[-1], 128):
        lane = pl.ds(col, 128)
        picked = jnp.take_along_axis(
            x_vmem[token, :, lane].astype(jnp.float32),
            rows,
            axis=0,
        )[0]
        if valid is not None:
          picked = jnp.where(valid, picked, jnp.zeros_like(picked))
        o_vmem[token, lane] = picked.astype(o_vmem.dtype)

  @jax.named_scope("gather_first_last")
  def gather_first_last():
    _gather_block(first=True, last=True)

  @jax.named_scope("gather_first")
  def gather_first():
    _gather_block(first=True, last=False)

  @jax.named_scope("gather")
  def gather():
    _gather_block(first=False, last=False)

  @jax.named_scope("gather_last")
  def gather_last():
    _gather_block(first=False, last=True)

  on_last = block_idx == (pl.num_programs(0) - 1)

  def _first_block_branch():
    jax.lax.cond(on_last, gather_first_last, gather_first)

  def _other_blocks_branch():
    jax.lax.cond(on_last, gather_last, gather)

  jax.lax.cond(block_idx == 0, _first_block_branch, _other_blocks_branch)
| 194 | + |
| 195 | + |
def tensorcore_gather(
    x: jax.Array,
    indices: jax.Array,
    start_idx: int | jax.Array | None = None,
    end_idx: int | jax.Array | None = None,
    block_size: int = 32,
) -> jax.Array:
  """Gathers a range of tokens from x using TensorCore.

  Produces rows `x[indices[i]]` for `i` in `[start_idx, end_idx)` via a
  custom Pallas pipeline: each token's 8-aligned slab is DMA'd from HBM by
  `GatherBufferedRef`, and `inner_kernel` selects the wanted row.

  Args:
    x: Source array of shape (num_rows, hidden_dim), placed in HBM.
    indices: 1-D integer array of row indices into `x`; its length must be a
      multiple of `block_size`.
    start_idx: First position in `indices` to gather. Defaults to 0.
    end_idx: One-past-last position in `indices` to gather. Defaults to
      `indices.shape[0]`.
    block_size: Tokens per pipeline step; must be a multiple of 8.

  Returns:
    Array of shape (len(indices), hidden_dim). Rows outside the processed
    block range may be left unwritten (uninitialized).

  Raises:
    ValueError: If `block_size` is not a multiple of 8, or `len(indices)` is
      not a multiple of `block_size`.
  """
  # Raise instead of `assert` so validation survives `python -O`, matching
  # the ValueError used for the total_indices precondition below.
  if block_size % 8 != 0:
    raise ValueError(f"block_size must be divisible by 8, got {block_size}")
  total_indices = indices.shape[0]
  hidden_dim = x.shape[1]
  dtype = x.dtype

  if start_idx is None:
    start_idx = 0
  if end_idx is None:
    end_idx = total_indices

  if total_indices % block_size != 0:
    raise ValueError(
        f"total_indices ({total_indices}) must be a multiple of block_size"
        f" ({block_size})."
    )

  # Snap the processed range outward to block boundaries; the inner kernel's
  # masking trims the head/tail back to [start_idx, end_idx).
  aligned_start = (start_idx // block_size) * block_size
  aligned_end = pl.cdiv(end_idx, block_size) * block_size
  num_blocks = pl.cdiv(aligned_end - aligned_start, block_size)
  local_start = start_idx - aligned_start

  # Split each index into an 8-aligned base (drives the HBM DMA) and a 0-7
  # remainder (selects the row inside the fetched slab). Padding by
  # block_size keeps the final block's prefetch in bounds.
  idx_aligned_padded = jnp.pad(indices & ~7, (0, block_size))
  idx_mod_8_padded = jnp.pad(indices & 7, (0, block_size))

  @jax.named_scope("tensorcore_gather_kernel")
  def gather_kernel(
      num_blocks_ref,
      aligned_start_ref,
      end_idx_ref,
      local_start_ref,
      idx_aligned_ref,
      idx_mod_8_ref,
      x_hbm_ref,
      o_hbm_ref,
  ):
    """Executes the Gather pipeline over a perfectly tiled local execution grid.

    Args:
      num_blocks_ref: Scalar value of the number of blocks to process.
      aligned_start_ref: The absolute start index, rounded down to the nearest
        multiple of `block_size`.
      end_idx_ref: The absolute end index. Used to mask out-of-bounds
        calculations in the final execution block.
      local_start_ref: Offset (in [0, block_size)) between the true
        `start_idx` and `aligned_start_ref`. Used to mask out invalid
        elements inside the very first execution block.
      idx_aligned_ref: A tensor of shape (total_indices + block_size,) that
        contains `indices & ~7`. Used to dispatch aligned HBM fetches for each
        token.
      idx_mod_8_ref: A tensor of shape (total_indices + block_size,) that
        contains `indices & 7`, providing the local sub-row offsets within the
        8-element chunks physically loaded from HBM.
      x_hbm_ref: The input tensor referenced in HBM logic.
      o_hbm_ref: The output tensor referenced in HBM logic.
    """
    inner_kernel_partial = functools.partial(
        inner_kernel,
        block_size,
        aligned_start_ref,
        end_idx_ref,
        local_start_ref,
        idx_mod_8_ref,
    )

    _in_specs = [
        pl.BlockSpec(
            index_map=lambda *idx: idx,
            memory_space=pltpu.VMEM,
            block_shape=(block_size, 8, hidden_dim),
        ),
    ]

    def o_index_map(i):
      # Output blocks are written starting at the block containing
      # `aligned_start`, not at block 0.
      start_block_idx = aligned_start_ref[0] // block_size
      return (start_block_idx + i, 0)

    _out_specs = [
        pl.BlockSpec(
            index_map=o_index_map,
            memory_space=pltpu.VMEM,
            block_shape=(block_size, hidden_dim),
            pipeline_mode=pl.Buffered(buffer_count=_NUM_BUFFERS),
        ),
    ]

    pipeline_func = pipeline.emit_pipeline(
        inner_kernel_partial,
        grid=(num_blocks_ref[0],),
        in_specs=_in_specs,
        out_specs=_out_specs,
    )

    # Custom input buffering: GatherBufferedRef overrides copy_in/wait_in to
    # perform the per-token aligned DMAs.
    x_alloc = GatherBufferedRef.create(
        spec=_in_specs[0],
        source_array=x_hbm_ref,
        block_size=block_size,
    )

    o_alloc = pipeline.BufferedRef.create(
        spec=_out_specs[0],
        dtype_or_type=pipeline._ref_to_value_aval(o_hbm_ref),
        buffer_type=pipeline.BufferType.OUTPUT,
        buffer_count=_NUM_BUFFERS,
        grid_rank=1,
        source_memory_space=pltpu.HBM,
    )

    def _run(allocs):
      # The input "ref" is the tuple GatherBufferedRef.copy_in unpacks.
      pipeline_func(
          (x_hbm_ref, idx_aligned_ref, aligned_start_ref),
          o_hbm_ref,
          allocations=allocs,
      )

    pl.run_scoped(_run, (x_alloc, o_alloc))

  x = pltpu.with_memory_space_constraint(x, pltpu.HBM)
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=6,
      in_specs=[
          pl.BlockSpec(
              memory_space=pltpu.HBM,
              pipeline_mode=pl.Buffered(buffer_count=_NUM_BUFFERS),
          ),
      ],
      out_specs=pl.BlockSpec(
          memory_space=pltpu.HBM,
          pipeline_mode=pl.Buffered(buffer_count=_NUM_BUFFERS),
      ),
      scratch_shapes=[],
  )

  def to_arr(v):
    # Wrap a scalar as a (1,)-shaped int32 array for scalar prefetch.
    return jnp.array([v], dtype=jnp.int32)

  res = pl.pallas_call(
      gather_kernel,
      out_shape=jax.ShapeDtypeStruct((total_indices, hidden_dim), dtype),
      grid_spec=grid_spec,
      name=f"tc_gather_hidden{hidden_dim}_numidx{total_indices}_block{block_size}",
      metadata={
          "block_size": str(block_size),
          "hidden_dim": str(hidden_dim),
          "total_indices": str(total_indices),
          "dtype": str(dtype),
          "num_buffers": str(_NUM_BUFFERS),
      },
  )(
      to_arr(num_blocks),
      to_arr(aligned_start),
      to_arr(end_idx),
      to_arr(local_start),
      idx_aligned_padded,
      idx_mod_8_padded,
      x,
  )

  return res
0 commit comments