Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,7 @@ class Layout(SomeLayout, enum.Enum):

SMEM_GMEM_COPY = enum.auto()
TMA_INDICES = enum.auto()
TMA_INDICES_4 = enum.auto()

# TODO(b/435159109): Remove this once LLVM regression is addressed.
_WGMMA_ACC_32BIT = enum.auto() # Temporarily exposed to work around LLVM bugs
Expand Down Expand Up @@ -1756,6 +1757,8 @@ def check_no_args():
)
case Layout.TMA_INDICES:
return mgpu.TMA_INDICES_LAYOUT
case Layout.TMA_INDICES_4:
return mgpu.TMA_INDICES_4_LAYOUT


class TMEMLayout(enum.Enum):
Expand Down
29 changes: 17 additions & 12 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,27 +1491,29 @@ def kernel(x_ref, o_ref, scratch_ref, barrier_ref):
np.testing.assert_array_equal(f(x), np.stack([x, x], axis=1))

@parameterized.parameters(
((),),
((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),),
((), plgpu.Layout.TMA_INDICES),
((), plgpu.Layout.TMA_INDICES_4),
((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), plgpu.Layout.TMA_INDICES),
)
def test_copy_gmem_to_smem_gather(self, transforms):
def test_copy_gmem_to_smem_gather(self, transforms, idxs_layout):
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Only works on a GPU with capability >= sm100")
dtype = jnp.int32
out_shape = (64, 128)
out_shape = (64 if idxs_layout == plgpu.Layout.TMA_INDICES else 4, 128)
shape = (128, 64 + out_shape[-1])
optimized_load = idxs_layout != plgpu.Layout.TMA_INDICES_4
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
in_specs=(
pl.BlockSpec(memory_space=plgpu.GMEM),
pl.BlockSpec(memory_space=plgpu.SMEM),
pl.BlockSpec(memory_space=plgpu.SMEM if optimized_load else plgpu.GMEM),
),
scratch_shapes=[plgpu.Barrier()],
)
def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref):
idxs = plgpu.load(idx_ref, (), layout=plgpu.Layout.TMA_INDICES)
idxs = plgpu.load(idx_ref, (), layout=idxs_layout, optimized=optimized_load)
plgpu.copy_gmem_to_smem(x_ref_gmem.at[idxs, 64:], o_ref, barrier_ref)
plgpu.barrier_wait(barrier_ref)

Expand All @@ -1520,19 +1522,22 @@ def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref):
np.testing.assert_array_equal(kernel(x, idx), x[idx, 64:])

@parameterized.parameters(
((),),
((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),),
((), plgpu.Layout.TMA_INDICES),
((), plgpu.Layout.TMA_INDICES_4),
((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), plgpu.Layout.TMA_INDICES),
)
def test_copy_smem_to_gmem_scatter(self, transforms):
def test_copy_smem_to_gmem_scatter(self, transforms, idxs_layout):
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Only works on a GPU with capability >= sm100")
# Make sure we can infer the layout in WG
indices_layout = tokens_layout = None
if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
indices_layout = plgpu.Layout.TMA_INDICES
indices_layout = idxs_layout
tokens_layout = plgpu.Layout.WGMMA
dtype = jnp.int32
shape = (64, 128)
num_indices = 64 if idxs_layout == plgpu.Layout.TMA_INDICES else 4
shape = (num_indices, 128)
optimized_load = idxs_layout != plgpu.Layout.TMA_INDICES_4
@functools.partial(
self.kernel,
out_shape=jax.ShapeDtypeStruct(shape, dtype),
Expand All @@ -1541,7 +1546,7 @@ def test_copy_smem_to_gmem_scatter(self, transforms):
def kernel(tokens_ref, perm_ref, o_ref, smem_ref):
smem_ref[...] = plgpu.load(tokens_ref, (), layout=tokens_layout, optimized=False)
plgpu.commit_smem()
idxs = plgpu.load(perm_ref, (), layout=indices_layout, optimized=False)
idxs = plgpu.load(perm_ref, (), layout=indices_layout, optimized=optimized_load)
plgpu.copy_smem_to_gmem(smem_ref, o_ref.at[idxs, :])
plgpu.wait_smem_to_gmem(0)

Expand Down
Loading