1414from __future__ import annotations
1515
1616from functools import lru_cache , partial
17- from typing import Any , Callable , NamedTuple , Sequence
17+ from typing import Any
1818
1919from jax ._src import api
2020from jax ._src .api_util import (
@@ -60,33 +60,46 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
6060 f , axis_name , static_broadcasted_argnums , donate_argnums , in_axes , out_axes )
6161 if isinstance (axis_name , core ._TempAxisName ): # pylint: disable=protected-access
6262 axis_name = repr (axis_name )
63- fun = _pmap_wrap_init (f , static_broadcasted_tuple )
63+ wrapped_fun = _pmap_wrap_init (f , static_broadcasted_tuple )
6464
6565 def infer_params (* args , ** kwargs ):
66- p = _prepare_pmap (fun , in_axes , out_axes , static_broadcasted_tuple ,
67- donate_tuple , devices , backend , axis_size , args , kwargs )
66+ fun , dyn_argnums , dyn_args = _get_dyn_args (
67+ wrapped_fun , static_broadcasted_tuple , args )
68+ dyn_args_flat , dyn_args_tree = tree_flatten ((dyn_args , kwargs ))
69+ in_axes_flat = _get_in_axes_flat (
70+ in_axes , dyn_argnums , dyn_args , kwargs , len (dyn_args_flat ),
71+ dyn_args_tree )
72+ local_axis_size = _mapped_axis_size (dyn_args_flat , in_axes_flat )
73+ donated_invars = _get_donated_invars (
74+ donate_tuple , dyn_args_tree , len (dyn_args_flat ))
75+ fun , out_axes_thunk = flat_out_axes (fun , out_axes )
76+ flat_fun , out_tree = flatten_fun (fun , dyn_args_tree )
77+ global_axis_size = _get_global_axis_size (local_axis_size , devices ,
78+ backend , axis_size )
6879 trace_state_clean = core .trace_state_clean ()
69- mesh = Mesh (_get_devices (p , backend ), (axis_name ,))
80+ mesh = Mesh (
81+ _get_devices (devices , local_axis_size , global_axis_size , backend ),
82+ (axis_name ,))
7083 _pmapped , in_specs , out_specs = _cached_shard_map (
71- p . flat_fun , mesh , p . in_axes_flat , p . out_axes_thunk , axis_name )
84+ flat_fun , mesh , in_axes_flat , out_axes_thunk , axis_name )
7285 jitted_f = api .jit (
7386 _pmapped ,
74- donate_argnums = [i for i , val in enumerate (p . donated_invars ) if val ])
87+ donate_argnums = [i for i , val in enumerate (donated_invars ) if val ])
7588 if xb .process_count () > 1 :
7689 if trace_state_clean :
7790 flat_global_args = [
7891 host_local_array_to_global_array (arr , global_mesh = mesh , pspec = spec )
79- for arr , spec in zip (p . flat_args , in_specs )]
92+ for arr , spec in zip (dyn_args_flat , in_specs )]
8093 else :
8194 flat_global_args = mhu .host_local_array_to_global_array (
82- p . flat_args , mesh , list (in_specs ))
95+ dyn_args_flat , mesh , list (in_specs ))
8396 else :
84- flat_global_args = p . flat_args
85- return jitted_f , flat_global_args , p , mesh , out_specs , donate_tuple
97+ flat_global_args = dyn_args_flat
98+ return jitted_f , flat_global_args , dyn_args_tree , out_tree , mesh , out_specs , donate_tuple
8699
87100 @util .wraps (f )
88101 def wrapped (* args , ** kwargs ):
89- jitted_f , flat_global_args , p , mesh , out_specs , _ = infer_params (
102+ jitted_f , flat_global_args , _ , out_tree , mesh , out_specs , _ = infer_params (
90103 * args , ** kwargs )
91104 outs = jitted_f (* flat_global_args )
92105 if xb .process_count () > 1 :
@@ -97,15 +110,15 @@ def wrapped(*args, **kwargs):
97110 ]
98111 else :
99112 outs = mhu .global_array_to_host_local_array (outs , mesh , out_specs ())
100- return tree_unflatten (p . out_tree (), outs )
113+ return tree_unflatten (out_tree (), outs )
101114
102115 def lower (* args , ** kwargs ):
103- jitted_f , flat_global_args , p , _ , _ , donate_tuple = infer_params (
116+ jitted_f , flat_global_args , in_tree , out_tree , _ , _ , donate_tuple = infer_params (
104117 * args , ** kwargs )
105118 abstract_args = list (map (core .shaped_abstractify , flat_global_args ))
106- args_info = stages .make_args_info (p . in_tree , abstract_args , donate_tuple )
119+ args_info = stages .make_args_info (in_tree , abstract_args , donate_tuple )
107120 lowered = jitted_f .trace (* flat_global_args ).lower ()
108- lowered = stages .Lowered (lowered ._lowering , args_info , p . out_tree (),
121+ lowered = stages .Lowered (lowered ._lowering , args_info , out_tree (),
109122 no_kwargs = lowered ._no_kwargs )
110123 return lowered
111124 wrapped .lower = lower
@@ -137,14 +150,14 @@ def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
137150 return tree_map (lambda x , ax : x if ax is None else lax .expand_dims (x , [ax ]),
138151 list (out ), list (out_axes_thunk ()))
139152
def _get_devices(devices, local_axis_size, global_axis_size, backend):
  """Resolve the device list backing the pmap mesh.

  Args:
    devices: explicit device sequence passed to ``pmap``, or ``None``.
    local_axis_size: mapped axis size on this process.
    global_axis_size: mapped axis size across all processes.
    backend: optional backend name used only when ``devices`` is ``None``.

  Returns:
    A prefix of the chosen device list, ``global_axis_size`` long in
    multi-process runs and ``local_axis_size`` long otherwise.
  """
  if devices is not None:
    devs = devices            # caller-supplied devices win
  elif backend is not None:
    devs = xb.devices(backend=backend)
  else:
    devs = xb.devices()
  size = global_axis_size if xb.process_count() > 1 else local_axis_size
  return devs[:size]
148161
149162
150163def _ensure_index_tuple (x ) -> tuple [int , ...]:
@@ -181,20 +194,6 @@ def _mapped_axis_size(args, in_axes):
181194 raise ValueError ("pmap requires at least one argument with a mapped axis." )
182195
183196
184- class PmapCallInfo (NamedTuple ):
185- flat_fun : lu .WrappedFun
186- in_tree : Any
187- out_tree : Callable
188- flat_args : Sequence [Any ]
189- donated_invars : Sequence [bool ]
190- in_axes_flat : Sequence [int | None ]
191- local_axis_size : int
192- out_axes_thunk : Callable
193- devices : Sequence | None
194- global_axis_size : int
195- is_explicit_global_axis_size : bool
196-
197-
198197def _get_global_axis_size (local_axis_size : int , in_devices , backend_name : str ,
199198 global_axis_size : int | None ):
200199 if (xb .process_count () == 1 and global_axis_size is not None and
@@ -221,36 +220,6 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
221220 return global_axis_size
222221
223222
224- def _prepare_pmap (fun : lu .WrappedFun , in_axes , out_axes ,
225- static_broadcasted_tuple , donate_tuple , in_devices ,
226- backend_name , axis_size , args , kwargs ):
227- fun , dyn_argnums , dyn_args = _get_dyn_args (fun , static_broadcasted_tuple , args )
228- dyn_args_flat , dyn_args_tree = tree_flatten ((dyn_args , kwargs ))
229- in_axes_flat = _get_in_axes_flat (in_axes , dyn_argnums , dyn_args , kwargs ,
230- len (dyn_args_flat ), dyn_args_tree )
231- local_axis_size = _mapped_axis_size (dyn_args_flat , in_axes_flat )
232- donated_invars = _get_donated_invars (donate_tuple , dyn_args_tree ,
233- len (dyn_args_flat ))
234-
235- fun , out_axes_thunk = flat_out_axes (fun , out_axes )
236- flat_fun , out_tree = flatten_fun (fun , dyn_args_tree )
237-
238- is_explicit_global_axis_size = axis_size is not None
239- global_axis_size = _get_global_axis_size (local_axis_size , in_devices ,
240- backend_name , axis_size )
241- return PmapCallInfo (flat_fun = flat_fun ,
242- in_tree = dyn_args_tree ,
243- out_tree = out_tree ,
244- flat_args = dyn_args_flat ,
245- donated_invars = donated_invars ,
246- in_axes_flat = in_axes_flat ,
247- local_axis_size = local_axis_size ,
248- out_axes_thunk = out_axes_thunk ,
249- devices = None if in_devices is None else tuple (in_devices ),
250- global_axis_size = global_axis_size ,
251- is_explicit_global_axis_size = is_explicit_global_axis_size )
252-
253-
254223def _pmap_wrap_init (f , static_broadcasted_tuple ):
255224 """Create a wrapped function with DebugInfo for pmap.
256225
0 commit comments