Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions jax/_src/pallas/mosaic/sc_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ def _masked_cummax_abstract_eval(x, mask):
return x


def _masked_cumop_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, mask,
*, reduction_kind: str):
def _masked_cumop_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x,
*maybe_mask, reduction_kind: str):
sign_bit_vec = None
# tpu.scan comparisons assume unsigned int predicates, so we compare
# with the sign bit flipped.
Expand All @@ -556,6 +556,10 @@ def _masked_cumop_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, mask,
sign_bit_vec = vector.broadcast(
x.type, arith.constant(i32, ir.IntegerAttr.get(i32, 0x80000000)))
x = arith.xori(x, sign_bit_vec)
match maybe_mask:
case (mask,): ...
case _:
mask = None
result = tpu.scan(
x.type, x, ir.Attribute.parse(f"#tpu.reduction_kind<{reduction_kind}>"),
mask=mask)
Expand All @@ -579,12 +583,8 @@ def _reduce_op_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axes,
raise NotImplementedError(
f"reductions require axes to be (0,) on SparseCore, but got {axes}.")
vec_dim = ctx.avals_in[0].shape[0]
i1t = ir.IntegerType.get_signless(1)
c1 = arith.constant(i1t, ir.IntegerAttr.get(i1t, 1))
x_shp = ctx.avals_in[0].shape
c1v = vector.broadcast(ir.VectorType.get(x_shp, c1.type), c1)
return vector.extract(
_masked_cumop_lowering_rule(ctx, x, c1v, reduction_kind=reduction_kind),
_masked_cumop_lowering_rule(ctx, x, reduction_kind=reduction_kind),
[], [vec_dim - 1])

sc_lowering.register_lowering_rule(
Expand Down Expand Up @@ -613,9 +613,7 @@ def cummax(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
"""
if x.ndim != 1:
raise NotImplementedError(f"cummax: x={x.aval} must be rank 1")
if mask is None:
mask = lax.full(x.shape, True)
return masked_cummax_p.bind(x, mask)
return masked_cummax_p.bind(x, *() if mask is None else (mask,))


def cummin(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
Expand All @@ -633,9 +631,7 @@ def cummin(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
"""
if x.ndim != 1:
raise NotImplementedError(f"cummin: x={x.aval} must be rank 1")
if mask is None:
mask = lax.full(x.shape, True)
return masked_cummin_p.bind(x, mask)
return masked_cummin_p.bind(x, *() if mask is None else (mask,))


@sc_lowering.register_lowering_rule(lax.cumsum_p)
Expand All @@ -647,11 +643,7 @@ def _cumsum_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axis,
raise NotImplementedError(f"SC cumsum: x={ctx.avals_in[0]} must be rank 1")
if reverse:
raise NotImplementedError("SC cumsum: reverse=True is not yet supported")
i1t = ir.IntegerType.get_signless(1)
c1 = arith.constant(i1t, ir.IntegerAttr.get(i1t, 1))
c1v = vector.broadcast(ir.VectorType.get(x.type.shape, c1.type), c1)
return tpu.scan(
x.type, x, ir.Attribute.parse("#tpu.reduction_kind<sum>"), mask=c1v)
return tpu.scan(x.type, x, ir.Attribute.parse("#tpu.reduction_kind<sum>"))


def cumsum(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
Expand All @@ -666,9 +658,7 @@ def cumsum(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
"""
if x.ndim != 1:
raise NotImplementedError(f"cumsum: x={x.aval} must be rank 1")
if mask is None:
mask = lax.full(x.shape, True)
return masked_cumsum_p.bind(x, mask)
return masked_cumsum_p.bind(x, *() if mask is None else (mask,))


masked_sort_p = jax_core.Primitive("masked_sort")
Expand Down
25 changes: 18 additions & 7 deletions tests/pallas/tpu_sparsecore_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,13 +1681,18 @@ def kernel(x_ref, o_ref):
x = jnp.arange(math.prod(in_shape), dtype=dtype).reshape(in_shape)
np.testing.assert_array_equal(kernel(x), x.reshape(out_shape))

@parameterized.product(dtype=[jnp.int32, jnp.float32])
def test_cumsum(self, dtype):
x = jnp.arange(self.sc_info.num_lanes, dtype=dtype)
@parameterized.product(size=[8, 16, 128], dtype=[jnp.int32, jnp.float32])
def test_cumsum(self, size, dtype):
if not jtu.is_cloud_tpu_at_least(2026, 6, 9) and size != self.num_lanes:
self.skipTest("Needs a newer libtpu")

x = jnp.arange(size, dtype=dtype)

@self.vector_subcore_kernel(
out_shape=x,
compiler_params=pltpu.CompilerParams(needs_layout_passes=False),
compiler_params=pltpu.CompilerParams(
needs_layout_passes=jtu.is_cloud_tpu_at_least(2026, 6, 9)
),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.cumsum(x_ref[...])
Expand All @@ -1700,7 +1705,9 @@ def test_reductions(self, dtype, op):
x = jnp.arange(self.sc_info.num_lanes, dtype=dtype)
@self.vector_subcore_kernel(
out_shape=x,
compiler_params=pltpu.CompilerParams(needs_layout_passes=False),
compiler_params=pltpu.CompilerParams(
needs_layout_passes=jtu.is_cloud_tpu_at_least(2026, 6, 9)
),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.full(o_ref.shape, op(x_ref[...]))
Expand Down Expand Up @@ -1743,7 +1750,9 @@ def test_masked_cumsum(self, dtype):

@self.vector_subcore_kernel(
out_shape=x,
compiler_params=pltpu.CompilerParams(needs_layout_passes=False),
compiler_params=pltpu.CompilerParams(
needs_layout_passes=jtu.is_cloud_tpu_at_least(2026, 6, 9)
),
)
def kernel(x_ref, o_ref):
o_ref[...] = plsc.cumsum(x_ref[...], mask=(x_ref[...] % 2) == 1)
Expand All @@ -1757,7 +1766,9 @@ def test_masked_cummax(self, dtype):

@self.vector_subcore_kernel(
out_shape=x,
compiler_params=pltpu.CompilerParams(needs_layout_passes=False),
compiler_params=pltpu.CompilerParams(
needs_layout_passes=jtu.is_cloud_tpu_at_least(2026, 6, 9)
),
)
def kernel(x_ref, o_ref):
o_ref[...] = plsc.cummax(x_ref[...], mask=(x_ref[...] % 2) == 1)
Expand Down
Loading