Skip to content

Commit 85edead

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:sc] Added overloads for parallel_loop and slightly reformatted the docstring
PiperOrigin-RevId: 811397784
1 parent c531c41 commit 85edead

File tree

1 file changed

+72
-62
lines changed

1 file changed

+72
-62
lines changed

jax/_src/pallas/mosaic/sc_primitives.py

Lines changed: 72 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections.abc import Callable, Sequence
1717
import enum
1818
import functools
19-
from typing import Any, TypeAlias
19+
from typing import TypeAlias, TypeVar, overload
2020

2121
import jax
2222
from jax import api_util
@@ -32,8 +32,8 @@
3232
from jax._src.lib.mlir.dialects import vector
3333
from jax._src.pallas import core as pallas_core
3434
from jax._src.pallas.mosaic import core as tpu_core
35-
from jax._src.pallas.mosaic import sc_lowering
3635
from jax._src.pallas.mosaic import lowering as tc_lowering
36+
from jax._src.pallas.mosaic import sc_lowering
3737
from jax._src.state import primitives as state_primitives
3838
from jax._src.state import types as state_types
3939
from jax.experimental.mosaic.dialects import tpu
@@ -48,6 +48,8 @@
4848
TransformedRef: TypeAlias = state_types.TransformedRef
4949
Ref: TypeAlias = state_types.AbstractRef | TransformedRef
5050

51+
_T = TypeVar("_T")
52+
5153
load_p = jax_core.Primitive("load")
5254
load_p.is_effectful = lambda params: True # type: ignore
5355

@@ -583,18 +585,32 @@ def _parallel_loop_lowering_rule(
583585
scf.yield_(carry_out)
584586
return for_op.results
585587

586-
CarryType: TypeAlias = Any
587588

589+
@overload
590+
def parallel_loop(
591+
lower: jax.typing.ArrayLike,
592+
upper: jax.typing.ArrayLike,
593+
step: jax.typing.ArrayLike = ...,
594+
*,
595+
unroll: int = ...,
596+
carry: None = None,
597+
) -> Callable[[Callable[[jax.Array], None]], None]:
598+
...
599+
600+
601+
@overload
588602
def parallel_loop(
589603
lower: jax.typing.ArrayLike,
590604
upper: jax.typing.ArrayLike,
591-
step: jax.typing.ArrayLike = 1,
605+
step: jax.typing.ArrayLike = ...,
592606
*,
593-
unroll: int = 1,
594-
carry: CarryType | None = None,
595-
) -> Callable[[Callable[[jax.Array, CarryType], CarryType] |
596-
Callable[[jax.Array], None]],
597-
CarryType | None]:
607+
unroll: int = ...,
608+
carry: _T,
609+
) -> Callable[[Callable[[jax.Array, _T], _T]], _T]:
610+
...
611+
612+
613+
def parallel_loop(lower, upper, step=1, *, unroll=1, carry=None):
598614
"""A parallel loop decorator.
599615
600616
The decorated function forms the loop body. It is called with the current
@@ -609,46 +625,42 @@ def parallel_loop(
609625
Cross-iteration dependencies traceable via carried values are allowed. Refs
610626
may not be carried.
611627
612-
Safe usage of carried value:
613-
```
614-
@parallel_loop(0, 64, step=8, carry=jnp.int32(1))
615-
def body(i, j):
616-
# Writes are independent across iterations.
617-
x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
618-
return j + 1
619-
```
620-
621-
Any pytree can be carried. The final value is returned by the decorator:
622-
```
623-
def body(i, my_tree: MyTree):
624-
# Writes are independent across iterations.
625-
x_ref[pl.ds(i, 8)] = my_tree.transform(jnp.arange(8))
626-
return my_tree.step(i)
627-
final_value = parallel_loop(0, 64, step=8, carry=MyTree())(body)
628-
```
629-
630-
Undefined result:
631-
```
632-
@parallel_loop(0, 64, step=4, carry=jnp.int32(1))
633-
def body(i, j):
634-
# Because the step size is 4, the array written is of size 8, and loop
635-
# iterations may be reordered, the values in indices 4-59 of x_ref are
636-
# unspecified after the loop. (The values in 0-3 and 60-63 are only written
637-
# by the first and last iterations, so are well-defined.)
638-
x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
639-
return j + 1
640-
```
641-
642-
Unsafe read of "previous" iteration's write (don't do this):
643-
```
644-
@parallel_loop(0, 64, 8, carry=jnp.int32(1))
645-
def body(i, j):
646-
# Unsafe because it depends on the side-effect of "previous" iterations,
647-
# which may be executed in parallel or reordered.
648-
mask = x_ref[pl.ds(0, 8)] < j
649-
x_ref[pl.ds(0, 8)] += jnp.where(mask, j + jnp.arange(8), 0)
650-
return j + 1
651-
```
628+
Safe usage of carried value::
629+
630+
@parallel_loop(0, 64, step=8, carry=jnp.int32(1))
631+
def body(i, j):
632+
# Writes are independent across iterations.
633+
x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
634+
return j + 1
635+
636+
Any pytree can be carried. The final value is returned by the decorator::
637+
638+
def body(i, my_tree: MyTree):
639+
# Writes are independent across iterations.
640+
x_ref[pl.ds(i, 8)] = my_tree.transform(jnp.arange(8))
641+
return my_tree.step(i)
642+
final_value = parallel_loop(0, 64, step=8, carry=MyTree())(body)
643+
644+
Undefined result::
645+
646+
@parallel_loop(0, 64, step=4, carry=jnp.int32(1))
647+
def body(i, j):
648+
# Because the step size is 4, the array written is of size 8, and loop
649+
# iterations may be reordered, the values in indices 4-59 of x_ref are
650+
# unspecified after the loop. (The values in 0-3 and 60-63 are only
651+
# written by the first and last iterations, so are well-defined.)
652+
x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
653+
return j + 1
654+
655+
Unsafe read of "previous" iteration's write (don't do this)::
656+
657+
@parallel_loop(0, 64, 8, carry=jnp.int32(1))
658+
def body(i, j):
659+
# Unsafe because it depends on the side-effect of "previous" iterations,
660+
# which may be executed in parallel or reordered.
661+
mask = x_ref[pl.ds(0, 8)] < j
662+
x_ref[pl.ds(0, 8)] += jnp.where(mask, j + jnp.arange(8), 0)
663+
return j + 1
652664
653665
Args:
654666
lower: The starting value of the loop index.
@@ -661,21 +673,19 @@ def body(i, j):
661673
A decorator that executes the given function in a parallel loop.
662674
"""
663675

664-
def decorator(
665-
body: (Callable[[jax.Array, CarryType], CarryType] |
666-
Callable[[jax.Array], None])
667-
) -> CarryType | None:
668-
carries, carry_tree = jax.tree.flatten(carry)
676+
def decorator(body):
677+
flat_carries, carry_tree = jax.tree.flatten(carry)
669678
def wrapped(idx, *carries):
670679
if carry is None:
671-
return body(idx) or () # type: ignore
672-
return jax.tree.leaves(body(idx, carry_tree.unflatten(carries))) # type: ignore
680+
body(idx)
681+
return []
682+
return jax.tree.leaves(body(idx, carry_tree.unflatten(carries)))
673683
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
674684
lu.wrap_init(
675685
wrapped,
676686
debug_info=api_util.debug_info("parallel_loop", body, (), {}),
677687
),
678-
[pallas_core.index_map_grid_aval, *(c.aval for c in carries)],
688+
[pallas_core.index_map_grid_aval, *(c.aval for c in flat_carries)],
679689
)
680690
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(
681691
jaxpr.effects
@@ -684,16 +694,16 @@ def wrapped(idx, *carries):
684694
raise NotImplementedError(
685695
f"Effects not supported in parallel_loop: {disallowed_effects}"
686696
)
687-
flat_args, tree = jax.tree.flatten((lower, upper, step, consts, carries))
697+
flat_args, tree = jax.tree.flatten(
698+
(lower, upper, step, consts, flat_carries)
699+
)
688700
flat_result = parallel_loop_p.bind(
689-
*flat_args,
690-
tree=tree,
691-
unroll=unroll,
692-
jaxpr=jaxpr,
701+
*flat_args, tree=tree, unroll=unroll, jaxpr=jaxpr
693702
)
694703
if carry is None:
695704
return None
696705
return carry_tree.unflatten(flat_result)
706+
697707
return decorator
698708

699709

0 commit comments

Comments
 (0)