1818
1919from jax ._src import api
2020from jax ._src .api_util import (
21- argnums_partial , donation_vector , flatten_fun ,
22- flat_out_axes , fun_signature , fun_sourceinfo )
21+ argnums_partial , donation_vector , fun_signature , fun_sourceinfo )
2322from jax ._src import array
2423from jax ._src import config
2524from jax ._src import core
3736from jax ._src .mesh import Mesh
3837from jax ._src .shard_map import _axes_to_pspec , _shard_map
3938from jax ._src .tree_util import (
40- broadcast_flattened_prefix_with_treedef , prefix_errors , tree_flatten , tree_map , tree_unflatten )
39+ broadcast_flattened_prefix_with_treedef , broadcast_prefix ,
40+ prefix_errors , tree_flatten , tree_map , tree_unflatten )
4141import numpy as np
4242
4343map , unsafe_map = util .safe_map , map
@@ -61,11 +61,13 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
6161 if isinstance (axis_name , core ._TempAxisName ): # pylint: disable=protected-access
6262 axis_name = repr (axis_name )
6363 wrapped_fun = _pmap_wrap_init (f , static_broadcasted_tuple )
64+ out_axes_flat , out_axes_tree = tree_flatten (out_axes )
65+ out_axes_flat = tuple (out_axes_flat )
6466
6567 def infer_params (* args , ** kwargs ):
6668 process_count = xb .process_count (backend )
6769 trace_state_clean = core .trace_state_clean ()
68- fun , dyn_argnums , dyn_args = _get_dyn_args (
70+ dyn_f , dyn_argnums , dyn_args = _get_dyn_args (
6971 wrapped_fun , static_broadcasted_tuple , args )
7072 dyn_args_flat , dyn_args_tree = tree_flatten ((dyn_args , kwargs ))
7173 in_axes_flat = _get_in_axes_flat (
@@ -76,11 +78,9 @@ def infer_params(*args, **kwargs):
7678 donate_tuple , dyn_args_tree , len (dyn_args_flat ))
7779 mesh_devices = _get_mesh_devices (
7880 devices , backend , local_axis_size , axis_size , trace_state_clean )
79- fun , out_axes_thunk = flat_out_axes (fun , out_axes )
80- flat_fun , out_tree = flatten_fun (fun , dyn_args_tree )
81- mesh = Mesh (mesh_devices , (axis_name ,))
82- _pmapped , in_specs , out_specs = _cached_shard_map (
83- flat_fun , mesh , in_axes_flat , out_axes_thunk , axis_name )
81+ _pmapped , in_specs , out_specs , mesh = _cached_shard_map (
82+ dyn_f , dyn_args_tree , in_axes_flat , out_axes_flat , out_axes_tree ,
83+ mesh_devices , axis_name )
8484 jitted_f = api .jit (
8585 _pmapped ,
8686 donate_argnums = [i for i , val in enumerate (donated_invars ) if val ])
@@ -94,61 +94,65 @@ def infer_params(*args, **kwargs):
9494 dyn_args_flat , mesh , list (in_specs ))
9595 else :
9696 flat_global_args = dyn_args_flat
97- return (jitted_f , flat_global_args , dyn_args_tree , out_tree , mesh , out_specs ,
97+ return (jitted_f , flat_global_args , dyn_args_tree , mesh , out_specs ,
9898 donate_tuple , process_count , trace_state_clean )
9999
100100 @util .wraps (f )
101101 def wrapped (* args , ** kwargs ):
102- (jitted_f , flat_global_args , _ , out_tree , mesh , out_specs ,
102+ (jitted_f , flat_global_args , _ , mesh , out_specs ,
103103 _ , process_count , trace_state_clean ) = infer_params (* args , ** kwargs )
104104 outs = jitted_f (* flat_global_args )
105105 if process_count > 1 :
106106 if trace_state_clean :
107- outs = [
108- global_array_to_host_local_array ( out , global_mesh = mesh , pspec = spec )
109- for out , spec in zip ( outs , out_specs ())
110- ]
107+ outs = tree_map (
108+ lambda out , spec : global_array_to_host_local_array (
109+ out , global_mesh = mesh , pspec = spec ),
110+ outs , out_specs , is_leaf = lambda x : x is None )
111111 else :
112- outs = mhu .global_array_to_host_local_array (outs , mesh , out_specs () )
113- return tree_unflatten ( out_tree (), outs )
112+ outs = mhu .global_array_to_host_local_array (outs , mesh , out_specs )
113+ return outs
114114
115115 def lower (* args , ** kwargs ):
116- jitted_f , flat_global_args , in_tree , out_tree , _ , _ , donate_tuple , _ , _ = infer_params (
116+ jitted_f , flat_global_args , in_tree , _ , _ , donate_tuple , _ , _ = infer_params (
117117 * args , ** kwargs )
118118 abstract_args = list (map (core .shaped_abstractify , flat_global_args ))
119119 args_info = stages .make_args_info (in_tree , abstract_args , donate_tuple )
120120 lowered = jitted_f .trace (* flat_global_args ).lower ()
121- lowered = stages .Lowered (lowered ._lowering , args_info , out_tree () ,
121+ lowered = stages .Lowered (lowered ._lowering , args_info , lowered . out_tree ,
122122 no_kwargs = lowered ._no_kwargs )
123123 return lowered
124124 wrapped .lower = lower
125125 return wrapped
126126
127127
@lu.cache
def _cached_shard_map(fun, in_tree, in_axes_flat,
                      out_axes_flat, out_axes_tree, mesh_devices, axis_name):
  """Build the `_shard_map`-wrapped implementation backing `pmap`.

  Cached via `lu.cache`, so all arguments are passed in hashable, flattened
  form (the `out_axes` pytree arrives as `out_axes_flat` + `out_axes_tree`).

  Args:
    fun: linear-util wrapped callable; invoked via `fun.call_wrapped`.
    in_tree: treedef for `(dyn_args, kwargs)`; used to unflatten the flat
      positional arguments before calling `fun`.
    in_axes_flat: per-leaf mapped input axis for each flat argument
      (`None` for an unmapped argument).
    out_axes_flat: flattened leaves of the user's `out_axes` pytree.
    out_axes_tree: treedef matching `out_axes_flat`.
    mesh_devices: devices used to construct the single-axis `Mesh`.
    axis_name: the pmap axis name; becomes the mesh's only axis name.

  Returns:
    A tuple `(_pmapped, in_specs, out_specs, mesh)` where `_pmapped` is the
    shard_map-transformed callable over flat arguments, `in_specs`/`out_specs`
    are the partition specs derived from the input/output axes, and `mesh` is
    the constructed single-axis `Mesh`.
  """
  mesh = Mesh(mesh_devices, (axis_name,))
  # Rebuild the user-facing out_axes pytree from its hashable flat form.
  out_axes = tree_unflatten(out_axes_tree, list(out_axes_flat))
  in_specs = tuple(map(partial(_axes_to_pspec, axis_name), in_axes_flat))
  # `None` is a meaningful out_axes leaf (unmapped output), so it must be
  # treated as a leaf rather than an empty subtree during the mapping.
  out_specs = tree_map(
      partial(_axes_to_pspec, axis_name), out_axes, is_leaf=lambda x: x is None
  )

  def _fun(*flat_args):
    # Remove the mapped dimension from each mapped input before calling the
    # user function (presumably it is a size-1 axis inside shard_map — TODO
    # confirm against `_shard_map`'s per-shard argument shapes).
    args = tree_map(
        lambda x, ax: x if ax is None else lax.squeeze(x, [ax]),
        flat_args,
        in_axes_flat,
    )
    args, kwargs = tree_unflatten(in_tree, args)
    out = fun.call_wrapped(*args, **kwargs)
    out_flat, out_tree = tree_flatten(out)
    # Broadcast the out_axes prefix against the actual output pytree.
    # NOTE: intentionally shadows the enclosing `out_axes_flat` parameter;
    # this is the per-output-leaf version of it.
    out_axes_flat = broadcast_prefix(out_axes, out, is_leaf=lambda x: x is None)
    # Re-insert a dimension at each mapped output axis so the result shapes
    # match what shard_map's out_specs expect.
    out_flat = tree_map(
        lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
        out_flat,
        out_axes_flat,
    )
    return tree_unflatten(out_tree, out_flat)

  _pmapped = _shard_map(_fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
                        check_vma=False, axis_names=set(mesh.axis_names))
  return (_pmapped, in_specs, out_specs, mesh)
153157
154158def _mapped_axis_size (args , in_axes ):
0 commit comments