@@ -1473,6 +1473,44 @@ def _scan_lowering_rule(
14731473 return for_out
14741474
14751475
1476+ @register_lowering_rule (lax .while_p )
1477+ def _while_lowering_rule (
1478+ ctx : LoweringRuleContext ,
1479+ * args ,
1480+ cond_jaxpr ,
1481+ body_jaxpr ,
1482+ cond_nconsts ,
1483+ body_nconsts ,
1484+ ):
1485+ # First try to lower via a simpler fori loop, which may optimize better.
1486+ fori_jaxpr , err = pallas_utils .pattern_match_while_to_fori_loop (
1487+ cond_jaxpr , cond_nconsts , body_jaxpr , body_nconsts
1488+ )
1489+ del cond_jaxpr , body_jaxpr
1490+ if fori_jaxpr is None :
1491+ raise NotImplementedError (err )
1492+
1493+ if fori_jaxpr .constvars :
1494+ raise NotImplementedError
1495+
1496+ lb_aval , ub_aval , * _ = ctx .avals_in [body_nconsts :]
1497+ # Reflect the changes of the pattern matcher to the context.
1498+ avals_in = (
1499+ * ctx .avals_in [cond_nconsts :body_nconsts ],
1500+ ctx .avals_in [body_nconsts ], # the index
1501+ * ctx .avals_in [body_nconsts + 2 :],
1502+ )
1503+
1504+ avals_out = tuple (ctx .avals_out [2 :])
1505+ ctx = ctx .replace (avals_in = avals_in , avals_out = avals_out )
1506+ _ , consts , (lb , ub , * args ) = util .split_list (args , [cond_nconsts , body_nconsts ])
1507+
1508+ lb , ub = _ensure_ir_value (lb , lb_aval .dtype ), _ensure_ir_value (ub , ub_aval .dtype )
1509+ length = arith_dialect .subi (ub , lb )
1510+
1511+ for_out = _lower_jaxpr_to_for_loop (ctx , fori_jaxpr , lb , length , consts , * args , has_loop_index = True )
1512+ return (ub , ub , * for_out )
1513+
14761514@register_lowering_rule (lax .cond_p )
14771515def _cond_lowering_rule (ctx : LoweringRuleContext , index , * args , branches ):
14781516 index_aval , * _arg_avals = ctx .avals_in
0 commit comments