Skip to content

Commit 8d84f28

Browse files
cperivol authored and
Google-ML-Automation committed
[pallas mgpu] Lowering for while loops as long as they are secretly for loops.
PiperOrigin-RevId: 698427307
1 parent 439d34d commit 8d84f28

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,6 +1473,44 @@ def _scan_lowering_rule(
14731473
return for_out
14741474

14751475

1476+
@register_lowering_rule(lax.while_p)
def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_jaxpr,
    body_jaxpr,
    cond_nconsts,
    body_nconsts,
):
  """Lowers ``lax.while_p`` when the loop is really a ``fori_loop``.

  The while loop is handed to
  ``pallas_utils.pattern_match_while_to_fori_loop``; if a fori-style jaxpr
  comes back it is lowered with ``_lower_jaxpr_to_for_loop``. General while
  loops are not supported by this rule.

  Args:
    ctx: Lowering context. Its ``avals_in`` / ``avals_out`` are rewritten
      below to match the fori-style jaxpr before the for-loop lowering runs.
    *args: Operands of the while primitive: ``cond_nconsts`` condition
      constants, then ``body_nconsts`` body constants, then the carry. The
      first two carry entries are used as the lower and upper loop bounds.
    cond_jaxpr: Jaxpr of the loop condition.
    body_jaxpr: Jaxpr of the loop body.
    cond_nconsts: Number of leading ``args`` that are condition constants.
    body_nconsts: Number of ``args`` after those that are body constants.

  Returns:
    ``(ub, ub, *for_out)``: the upper bound twice, standing in for the two
    bound entries of the carry, followed by the for-loop results.

  Raises:
    NotImplementedError: If the loop cannot be pattern-matched to a fori loop,
      or if the matched jaxpr has ``constvars``.
  """
  # First try to lower via a simpler fori loop, which may optimize better.
  fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  del cond_jaxpr, body_jaxpr
  if fori_jaxpr is None:
    raise NotImplementedError(err)

  if fori_jaxpr.constvars:
    raise NotImplementedError(
        "While loops whose fori_loop form has constvars are not supported."
    )

  # The carry starts after *both* constant groups. This is the same split that
  # `util.split_list` applies to `args` below, so avals and values stay
  # aligned. NOTE(review): the pattern matcher presumably only succeeds when
  # cond_nconsts == 0, in which case this equals the previous `body_nconsts`
  # offset -- confirm.
  carry_start = cond_nconsts + body_nconsts
  lb_aval, ub_aval, *_ = ctx.avals_in[carry_start:]
  # Reflect the changes of the pattern matcher to the context.
  avals_in = (
      *ctx.avals_in[cond_nconsts:carry_start],
      ctx.avals_in[carry_start],  # the index
      *ctx.avals_in[carry_start + 2:],
  )
  avals_out = tuple(ctx.avals_out[2:])
  ctx = ctx.replace(avals_in=avals_in, avals_out=avals_out)
  _, consts, (lb, ub, *args) = util.split_list(
      args, [cond_nconsts, body_nconsts]
  )

  lb = _ensure_ir_value(lb, lb_aval.dtype)
  ub = _ensure_ir_value(ub, ub_aval.dtype)
  # The for-loop helper is given a start and a trip count, not an end bound.
  length = arith_dialect.subi(ub, lb)

  for_out = _lower_jaxpr_to_for_loop(
      ctx, fori_jaxpr, lb, length, consts, *args, has_loop_index=True
  )
  return (ub, ub, *for_out)
1513+
14761514
@register_lowering_rule(lax.cond_p)
14771515
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
14781516
index_aval, *_arg_avals = ctx.avals_in

tests/pallas/mosaic_gpu_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,22 @@ def kernel(o_ref):
676676

677677
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
678678

679+
def test_fori_loop_dynamic_bounds(self):
  """Runs a ``fori_loop`` whose bounds are computed inside the kernel.

  Both bounds are offset by ``pl.program_id(0)``, so they are not Python
  constants. The kernel broadcasts the loop result over the output, which is
  expected to be 5 everywhere.
  """

  def add_index(i, acc):
    return acc + i

  @functools.partial(
      pl.pallas_call,
      out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
      grid=(1,),
  )
  def kernel(o_ref):
    offset = pl.program_id(0)
    lower = 2 + offset
    upper = 4 + offset
    # Equivalent to 2 + 3.
    total = jax.lax.fori_loop(lower, upper, add_index, 0)
    o_ref[...] = jax.lax.broadcast(total, o_ref.shape)

  actual = kernel()
  expected = jnp.full([256], 5, dtype=jnp.int32)
  np.testing.assert_array_equal(actual, expected)
694+
679695
def test_fori_loop_tuple(self):
680696
@functools.partial(
681697
pl.pallas_call,

0 commit comments

Comments
 (0)