1616from collections .abc import Callable , Sequence
1717import enum
1818import functools
19- from typing import Any , TypeAlias
19+ from typing import TypeAlias , TypeVar , overload
2020
2121import jax
2222from jax import api_util
3232from jax ._src .lib .mlir .dialects import vector
3333from jax ._src .pallas import core as pallas_core
3434from jax ._src .pallas .mosaic import core as tpu_core
35- from jax ._src .pallas .mosaic import sc_lowering
3635from jax ._src .pallas .mosaic import lowering as tc_lowering
36+ from jax ._src .pallas .mosaic import sc_lowering
3737from jax ._src .state import primitives as state_primitives
3838from jax ._src .state import types as state_types
3939from jax .experimental .mosaic .dialects import tpu
4848TransformedRef : TypeAlias = state_types .TransformedRef
4949Ref : TypeAlias = state_types .AbstractRef | TransformedRef
5050
51+ _T = TypeVar ("_T" )
52+
5153load_p = jax_core .Primitive ("load" )
5254load_p .is_effectful = lambda params : True # type: ignore
5355
@@ -583,18 +585,32 @@ def _parallel_loop_lowering_rule(
583585 scf .yield_ (carry_out )
584586 return for_op .results
585587
586- CarryType : TypeAlias = Any
587588
589+ @overload
590+ def parallel_loop (
591+ lower : jax .typing .ArrayLike ,
592+ upper : jax .typing .ArrayLike ,
593+ step : jax .typing .ArrayLike = ...,
594+ * ,
595+ unroll : int = ...,
596+ carry : None = None ,
597+ ) -> Callable [[Callable [[jax .Array ], None ]], None ]:
598+ ...
599+
600+
601+ @overload
588602def parallel_loop (
589603 lower : jax .typing .ArrayLike ,
590604 upper : jax .typing .ArrayLike ,
591- step : jax .typing .ArrayLike = 1 ,
605+ step : jax .typing .ArrayLike = ... ,
592606 * ,
593- unroll : int = 1 ,
594- carry : CarryType | None = None ,
595- ) -> Callable [[Callable [[jax .Array , CarryType ], CarryType ] |
596- Callable [[jax .Array ], None ]],
597- CarryType | None ]:
607+ unroll : int = ...,
608+ carry : _T ,
609+ ) -> Callable [[Callable [[jax .Array , _T ], _T ]], _T ]:
610+ ...
611+
612+
613+ def parallel_loop (lower , upper , step = 1 , * , unroll = 1 , carry = None ):
598614 """A parallel loop decorator.
599615
600616 The decorated function forms the loop body. It is called with the current
@@ -609,46 +625,42 @@ def parallel_loop(
609625 Cross-iteration dependencies traceable via carried values are allowed. Refs
610626 may not be carried.
611627
612- Safe usage of carried value:
613- ```
614- @parallel_loop(0, 64, step=8, carry=jnp.int32(1))
615- def body(i, j):
616- # Writes are independent across iterations.
617- x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
618- return j + 1
619- ```
620-
621- Any pytree can be carried. The final value is returned by the decorator:
622- ```
623- def body(i, my_tree: MyTree):
624- # Writes are independent across iterations.
625- x_ref[pl.ds(i, 8)] = my_tree.transform(jnp.arange(8))
626- return my_tree.step(i)
627- final_value = parallel_loop(0, 64, step=8, carry=MyTree())(body)
628- ```
629-
630- Undefined result:
631- ```
632- @parallel_loop(0, 64, step=4, carry=jnp.int32(1))
633- def body(i, j):
634- # Because the step size is 4, the array written is of size 8, and loop
635- # iterations may be reordered, the values in indices 4-59 of x_ref are
636- # unspecified after the loop. (The values in 0-3 and 60-63 are only written
637- # by the first and last iterations, so are well-defined.)
638- x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
639- return j + 1
640- ```
641-
642- Unsafe read of "previous" iteration's write (don't do this):
643- ```
644- @parallel_loop(0, 64, 8, carry=jnp.int32(1))
645- def body(i, j):
646- # Unsafe because it depends on the side-effect of "previous" iterations,
647- # which may be executed in parallel or reordered.
648- mask = x_ref[pl.ds(0, 8)] < j
649- x_ref[pl.ds(0, 8)] += jnp.where(mask, j + jnp.arange(8), 0)
650- return j + 1
651- ```
628+ Safe usage of carried value::
629+
630+ @parallel_loop(0, 64, step=8, carry=jnp.int32(1))
631+ def body(i, j):
632+ # Writes are independent across iterations.
633+ x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
634+ return j + 1
635+
636+ Any pytree can be carried. The final value is returned by the decorator::
637+
638+ def body(i, my_tree: MyTree):
639+ # Writes are independent across iterations.
640+ x_ref[pl.ds(i, 8)] = my_tree.transform(jnp.arange(8))
641+ return my_tree.step(i)
642+ final_value = parallel_loop(0, 64, step=8, carry=MyTree())(body)
643+
644+ Undefined result::
645+
646+ @parallel_loop(0, 64, step=4, carry=jnp.int32(1))
647+ def body(i, j):
648+ # Because the step size is 4, the array written is of size 8, and loop
649+ # iterations may be reordered, the values in indices 4-59 of x_ref are
650+ # unspecified after the loop. (The values in 0-3 and 60-63 are only
651+ # written by the first and last iterations, so are well-defined.)
652+ x_ref[pl.ds(i, 8)] = j + jnp.arange(8)
653+ return j + 1
654+
655+ Unsafe read of "previous" iteration's write (don't do this)::
656+
657+ @parallel_loop(0, 64, 8, carry=jnp.int32(1))
658+ def body(i, j):
659+ # Unsafe because it depends on the side-effect of "previous" iterations,
660+ # which may be executed in parallel or reordered.
661+ mask = x_ref[pl.ds(0, 8)] < j
662+ x_ref[pl.ds(0, 8)] += jnp.where(mask, j + jnp.arange(8), 0)
663+ return j + 1
652664
653665 Args:
654666 lower: The starting value of the loop index.
@@ -661,21 +673,19 @@ def body(i, j):
661673 A decorator that executes the given function in a parallel loop.
662674 """
663675
664- def decorator (
665- body : (Callable [[jax .Array , CarryType ], CarryType ] |
666- Callable [[jax .Array ], None ])
667- ) -> CarryType | None :
668- carries , carry_tree = jax .tree .flatten (carry )
676+ def decorator (body ):
677+ flat_carries , carry_tree = jax .tree .flatten (carry )
669678 def wrapped (idx , * carries ):
670679 if carry is None :
671- return body (idx ) or () # type: ignore
672- return jax .tree .leaves (body (idx , carry_tree .unflatten (carries ))) # type: ignore
680+ body (idx )
681+ return []
682+ return jax .tree .leaves (body (idx , carry_tree .unflatten (carries )))
673683 jaxpr , _ , consts = pe .trace_to_jaxpr_dynamic (
674684 lu .wrap_init (
675685 wrapped ,
676686 debug_info = api_util .debug_info ("parallel_loop" , body , (), {}),
677687 ),
678- [pallas_core .index_map_grid_aval , * (c .aval for c in carries )],
688+ [pallas_core .index_map_grid_aval , * (c .aval for c in flat_carries )],
679689 )
680690 disallowed_effects = effects .control_flow_allowed_effects .filter_not_in (
681691 jaxpr .effects
@@ -684,16 +694,16 @@ def wrapped(idx, *carries):
684694 raise NotImplementedError (
685695 f"Effects not supported in parallel_loop: { disallowed_effects } "
686696 )
687- flat_args , tree = jax .tree .flatten ((lower , upper , step , consts , carries ))
697+ flat_args , tree = jax .tree .flatten (
698+ (lower , upper , step , consts , flat_carries )
699+ )
688700 flat_result = parallel_loop_p .bind (
689- * flat_args ,
690- tree = tree ,
691- unroll = unroll ,
692- jaxpr = jaxpr ,
701+ * flat_args , tree = tree , unroll = unroll , jaxpr = jaxpr
693702 )
694703 if carry is None :
695704 return None
696705 return carry_tree .unflatten (flat_result )
706+
697707 return decorator
698708
699709
0 commit comments