@@ -63,6 +63,8 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
6363 wrapped_fun = _pmap_wrap_init (f , static_broadcasted_tuple )
6464
6565 def infer_params (* args , ** kwargs ):
66+ process_count = xb .process_count (backend )
67+ trace_state_clean = core .trace_state_clean ()
6668 fun , dyn_argnums , dyn_args = _get_dyn_args (
6769 wrapped_fun , static_broadcasted_tuple , args )
6870 dyn_args_flat , dyn_args_tree = tree_flatten ((dyn_args , kwargs ))
@@ -72,20 +74,17 @@ def infer_params(*args, **kwargs):
7274 local_axis_size = _mapped_axis_size (dyn_args_flat , in_axes_flat )
7375 donated_invars = _get_donated_invars (
7476 donate_tuple , dyn_args_tree , len (dyn_args_flat ))
77+ mesh_devices = _get_mesh_devices (
78+ devices , backend , local_axis_size , axis_size , trace_state_clean )
7579 fun , out_axes_thunk = flat_out_axes (fun , out_axes )
7680 flat_fun , out_tree = flatten_fun (fun , dyn_args_tree )
77- global_axis_size = _get_global_axis_size (local_axis_size , devices ,
78- backend , axis_size )
79- trace_state_clean = core .trace_state_clean ()
80- mesh = Mesh (
81- _get_devices (devices , local_axis_size , global_axis_size , backend ),
82- (axis_name ,))
81+ mesh = Mesh (mesh_devices , (axis_name ,))
8382 _pmapped , in_specs , out_specs = _cached_shard_map (
8483 flat_fun , mesh , in_axes_flat , out_axes_thunk , axis_name )
8584 jitted_f = api .jit (
8685 _pmapped ,
8786 donate_argnums = [i for i , val in enumerate (donated_invars ) if val ])
88- if xb . process_count () > 1 :
87+ if process_count > 1 :
8988 if trace_state_clean :
9089 flat_global_args = [
9190 host_local_array_to_global_array (arr , global_mesh = mesh , pspec = spec )
@@ -95,15 +94,16 @@ def infer_params(*args, **kwargs):
9594 dyn_args_flat , mesh , list (in_specs ))
9695 else :
9796 flat_global_args = dyn_args_flat
98- return jitted_f , flat_global_args , dyn_args_tree , out_tree , mesh , out_specs , donate_tuple
97+ return (jitted_f , flat_global_args , dyn_args_tree , out_tree , mesh , out_specs ,
98+ donate_tuple , process_count , trace_state_clean )
9999
100100 @util .wraps (f )
101101 def wrapped (* args , ** kwargs ):
102- jitted_f , flat_global_args , _ , out_tree , mesh , out_specs , _ = infer_params (
103- * args , ** kwargs )
102+ ( jitted_f , flat_global_args , _ , out_tree , mesh , out_specs ,
103+ _ , process_count , trace_state_clean ) = infer_params ( * args , ** kwargs )
104104 outs = jitted_f (* flat_global_args )
105- if xb . process_count () > 1 :
106- if core . trace_state_clean () :
105+ if process_count > 1 :
106+ if trace_state_clean :
107107 outs = [
108108 global_array_to_host_local_array (out , global_mesh = mesh , pspec = spec )
109109 for out , spec in zip (outs , out_specs ())
@@ -113,7 +113,7 @@ def wrapped(*args, **kwargs):
113113 return tree_unflatten (out_tree (), outs )
114114
115115 def lower (* args , ** kwargs ):
116- jitted_f , flat_global_args , in_tree , out_tree , _ , _ , donate_tuple = infer_params (
116+ jitted_f , flat_global_args , in_tree , out_tree , _ , _ , donate_tuple , _ , _ = infer_params (
117117 * args , ** kwargs )
118118 abstract_args = list (map (core .shaped_abstractify , flat_global_args ))
119119 args_info = stages .make_args_info (in_tree , abstract_args , donate_tuple )
@@ -150,22 +150,6 @@ def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
150150 return tree_map (lambda x , ax : x if ax is None else lax .expand_dims (x , [ax ]),
151151 list (out ), list (out_axes_thunk ()))
152152
153- def _get_devices (devices , local_axis_size , global_axis_size , backend ):
154- if backend is not None and devices is None :
155- devs = xb .devices (backend = backend )
156- else :
157- devs = xb .devices () if devices is None else devices
158- if xb .process_count () > 1 :
159- return devs [:global_axis_size ]
160- return devs [:local_axis_size ]
161-
162-
163- def _ensure_index_tuple (x ) -> tuple [int , ...]:
164- try :
165- return (int (x ),)
166- except TypeError :
167- return tuple (int (i ) for i in x )
168-
169153
170154def _mapped_axis_size (args , in_axes ):
171155 """Infer axis size from the first mapped argument.
@@ -194,32 +178,6 @@ def _mapped_axis_size(args, in_axes):
194178 raise ValueError ("pmap requires at least one argument with a mapped axis." )
195179
196180
197- def _get_global_axis_size (local_axis_size : int , in_devices , backend_name : str ,
198- global_axis_size : int | None ):
199- if (xb .process_count () == 1 and global_axis_size is not None and
200- global_axis_size != local_axis_size ):
201- raise ValueError (
202- f"Specified axis_size { global_axis_size } doesn't match received "
203- f"axis_size { local_axis_size } ." )
204-
205- if in_devices is not None and backend_name is None :
206- backend = xb .get_device_backend (in_devices [0 ])
207- else :
208- backend = xb .get_backend (backend_name )
209-
210- if global_axis_size is None :
211- if xb .process_count (backend ) == 1 :
212- global_axis_size = local_axis_size
213- elif in_devices is not None :
214- global_axis_size = len (in_devices )
215- else :
216- global_axis_size = local_axis_size * xb .process_count (backend )
217- assert all (
218- len (xb .local_devices (pi , backend )) == xb .local_device_count (backend )
219- for pi in range (xb .process_count (backend )))
220- return global_axis_size
221-
222-
223181def _pmap_wrap_init (f , static_broadcasted_tuple ):
224182 """Create a wrapped function with DebugInfo for pmap.
225183
@@ -387,6 +345,72 @@ def _get_donated_invars(donate_tuple, in_tree, num_flat_args):
387345 return (False ,) * num_flat_args
388346
389347
@lru_cache
def _get_mesh_devices(devices, backend, local_axis_size, axis_size,
                      trace_state_clean):
  """Compute effective mesh devices based on context.

  Args:
    devices: The mesh devices tuple, or None. Must be hashable (a tuple,
      not a list) because this function is memoized with lru_cache.
    backend: The backend to use, or None for the default backend.
    local_axis_size: The local axis size (per-process).
    axis_size: User-specified global axis size (optional).
    trace_state_clean: True if in execution mode (not tracing).

  Returns:
    Tuple of effective mesh devices sliced appropriately.

  Raises:
    ValueError: If axis_size doesn't match inferred size in single-process.
  """
  process_count = xb.process_count(backend)

  # Validate explicit axis_size in single-process mode.
  if (process_count == 1 and axis_size is not None and
      axis_size != local_axis_size):
    raise ValueError(
        f"Specified axis_size {axis_size} doesn't match received "
        f"axis_size {local_axis_size}.")

  # Compute global_axis_size.
  if axis_size is not None:
    global_axis_size = axis_size
  elif process_count > 1:
    if devices is not None:
      # An explicit device list defines the global axis size. This restores
      # the `in_devices is not None -> len(in_devices)` branch of the
      # replaced _get_global_axis_size; without it, a user-supplied device
      # list whose length differs from local_axis_size * process_count would
      # be mis-sliced below.
      global_axis_size = len(devices)
    else:
      global_axis_size = local_axis_size * process_count
    # Validate all processes have the same number of local devices.
    assert all(
        len(xb.local_devices(pi, backend)) == xb.local_device_count(backend)
        for pi in range(process_count))
  else:
    global_axis_size = local_axis_size

  # Determine mesh devices.
  if devices is not None:
    mesh_devices = devices
  elif process_count > 1:
    # Multi-process: group devices by process (host) for optimal collective
    # performance. This matches the old pmap's device ordering which uses
    # local_devices(process_index) in a nested loop, ensuring devices from
    # the same host are contiguous in the mesh.
    mesh_devices = tuple(
        d
        for process_index in range(process_count)
        for d in xb.local_devices(process_index, backend))
  elif backend is not None:
    mesh_devices = tuple(xb.devices(backend=backend))
  else:
    mesh_devices = tuple(xb.devices())

  if not trace_state_clean and process_count > 1:
    # Tracing in multihost: use local devices only.
    return tuple(xb.local_devices(backend=backend)[:local_axis_size])
  return mesh_devices[:global_axis_size]
413+
390414@lru_cache
391415def _local_to_global_aval (local_aval , mesh , pspec ):
392416 pspec = sharding_impls .prepare_axis_resources (pspec , 'pspec to array_mapping' )
0 commit comments