134 changes: 79 additions & 55 deletions jax/_src/pmap.py
@@ -63,6 +63,8 @@ def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
wrapped_fun = _pmap_wrap_init(f, static_broadcasted_tuple)

def infer_params(*args, **kwargs):
process_count = xb.process_count(backend)
trace_state_clean = core.trace_state_clean()
fun, dyn_argnums, dyn_args = _get_dyn_args(
wrapped_fun, static_broadcasted_tuple, args)
dyn_args_flat, dyn_args_tree = tree_flatten((dyn_args, kwargs))
@@ -72,20 +74,17 @@ def infer_params(*args, **kwargs):
local_axis_size = _mapped_axis_size(dyn_args_flat, in_axes_flat)
donated_invars = _get_donated_invars(
donate_tuple, dyn_args_tree, len(dyn_args_flat))
mesh_devices = _get_mesh_devices(
devices, backend, local_axis_size, axis_size, trace_state_clean)
fun, out_axes_thunk = flat_out_axes(fun, out_axes)
flat_fun, out_tree = flatten_fun(fun, dyn_args_tree)
global_axis_size = _get_global_axis_size(local_axis_size, devices,
backend, axis_size)
trace_state_clean = core.trace_state_clean()
mesh = Mesh(
_get_devices(devices, local_axis_size, global_axis_size, backend),
(axis_name,))
mesh = Mesh(mesh_devices, (axis_name,))
_pmapped, in_specs, out_specs = _cached_shard_map(
flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name)
jitted_f = api.jit(
_pmapped,
donate_argnums=[i for i, val in enumerate(donated_invars) if val])
if xb.process_count() > 1:
if process_count > 1:
if trace_state_clean:
flat_global_args = [
host_local_array_to_global_array(arr, global_mesh=mesh, pspec=spec)
@@ -95,15 +94,16 @@ def infer_params(*args, **kwargs):
dyn_args_flat, mesh, list(in_specs))
else:
flat_global_args = dyn_args_flat
return jitted_f, flat_global_args, dyn_args_tree, out_tree, mesh, out_specs, donate_tuple
return (jitted_f, flat_global_args, dyn_args_tree, out_tree, mesh, out_specs,
donate_tuple, process_count, trace_state_clean)

@util.wraps(f)
def wrapped(*args, **kwargs):
jitted_f, flat_global_args, _, out_tree, mesh, out_specs, _ = infer_params(
*args, **kwargs)
(jitted_f, flat_global_args, _, out_tree, mesh, out_specs,
_, process_count, trace_state_clean) = infer_params(*args, **kwargs)
outs = jitted_f(*flat_global_args)
if xb.process_count() > 1:
if core.trace_state_clean():
if process_count > 1:
if trace_state_clean:
outs = [
global_array_to_host_local_array(out, global_mesh=mesh, pspec=spec)
for out, spec in zip(outs, out_specs())
@@ -113,7 +113,7 @@ def wrapped(*args, **kwargs):
return tree_unflatten(out_tree(), outs)
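
For orientation, here is a minimal sketch of the jit-over-shard_map dispatch pattern that `infer_params` and `wrapped` assemble above, written against the public `jax.experimental.shard_map` API rather than the internal `_cached_shard_map` helper. The function `f`, the mesh axis name `"i"`, and the input shape are illustrative assumptions; the real path additionally handles `in_axes`/`out_axes` reshaping, donation, and the multi-host global/host-local array conversions shown above.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = jax.devices()
mesh = Mesh(np.array(devices), ("i",))  # 1D mesh over the available devices

def f(x):
  # Per-shard computation: each device sees its slice along the mapped axis.
  return x * 2

# pmap now lowers to (roughly) shard_map over a 1D mesh, wrapped in jit.
mapped = shard_map(f, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
jitted = jax.jit(mapped)

x = np.arange(len(devices) * 2, dtype=np.float32)  # leading dim divisible by mesh size
print(jitted(x))
```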

def lower(*args, **kwargs):
jitted_f, flat_global_args, in_tree, out_tree, _, _, donate_tuple = infer_params(
jitted_f, flat_global_args, in_tree, out_tree, _, _, donate_tuple, _, _ = infer_params(
*args, **kwargs)
abstract_args = list(map(core.shaped_abstractify, flat_global_args))
args_info = stages.make_args_info(in_tree, abstract_args, donate_tuple)
@@ -150,22 +150,6 @@ def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
return tree_map(lambda x, ax: x if ax is None else lax.expand_dims(x, [ax]),
list(out), list(out_axes_thunk()))

def _get_devices(devices, local_axis_size, global_axis_size, backend):
if backend is not None and devices is None:
devs = xb.devices(backend=backend)
else:
devs = xb.devices() if devices is None else devices
if xb.process_count() > 1:
return devs[:global_axis_size]
return devs[:local_axis_size]


def _ensure_index_tuple(x) -> tuple[int, ...]:
try:
return (int(x),)
except TypeError:
return tuple(int(i) for i in x)


def _mapped_axis_size(args, in_axes):
"""Infer axis size from the first mapped argument.
Expand Down Expand Up @@ -194,32 +178,6 @@ def _mapped_axis_size(args, in_axes):
raise ValueError("pmap requires at least one argument with a mapped axis.")
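
As a concrete illustration of what this helper computes (hypothetical shapes, not the real implementation), the mapped axis size is simply the length of an argument's mapped dimension:

```python
import numpy as np

x = np.zeros((8, 3))
in_axis = 0                         # in_axes=0: map over the leading dimension
local_axis_size = x.shape[in_axis]
assert local_axis_size == 8
```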


def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
global_axis_size: int | None):
if (xb.process_count() == 1 and global_axis_size is not None and
global_axis_size != local_axis_size):
raise ValueError(
f"Specified axis_size {global_axis_size} doesn't match received "
f"axis_size {local_axis_size}.")

if in_devices is not None and backend_name is None:
backend = xb.get_device_backend(in_devices[0])
else:
backend = xb.get_backend(backend_name)

if global_axis_size is None:
if xb.process_count(backend) == 1:
global_axis_size = local_axis_size
elif in_devices is not None:
global_axis_size = len(in_devices)
else:
global_axis_size = local_axis_size * xb.process_count(backend)
assert all(
len(xb.local_devices(pi, backend)) == xb.local_device_count(backend)
for pi in range(xb.process_count(backend)))
return global_axis_size


def _pmap_wrap_init(f, static_broadcasted_tuple):
"""Create a wrapped function with DebugInfo for pmap.

@@ -387,6 +345,72 @@ def _get_donated_invars(donate_tuple, in_tree, num_flat_args):
return (False,) * num_flat_args


@lru_cache
def _get_mesh_devices(devices, backend, local_axis_size, axis_size,
trace_state_clean):
"""Compute effective mesh devices based on context.

Args:
devices: The mesh devices tuple.
backend: The backend to use.
local_axis_size: The local axis size (per-process).
axis_size: User-specified global axis size (optional).
trace_state_clean: True if in execution mode (not tracing).

Returns:
Tuple of effective mesh devices sliced appropriately.

Raises:
ValueError: If axis_size doesn't match inferred size in single-process.
"""
process_count = xb.process_count(backend)

# Validate explicit axis_size in single-process mode
if (process_count == 1 and axis_size is not None and
axis_size != local_axis_size):
raise ValueError(
f"Specified axis_size {axis_size} doesn't match received "
f"axis_size {local_axis_size}.")

# Compute global_axis_size
if axis_size is not None:
global_axis_size = axis_size
elif process_count > 1:
global_axis_size = local_axis_size * process_count
# Validate all processes have the same number of local devices
assert all(
len(xb.local_devices(pi, backend)) == xb.local_device_count(backend)
for pi in range(process_count))
else:
global_axis_size = local_axis_size

# Determine mesh devices
if devices is not None:
mesh_devices = devices
elif process_count > 1:
# Multi-process: group devices by process (host) for optimal collective
# performance. This matches the old pmap's device ordering which uses
# local_devices(process_index) in a nested loop, ensuring devices from
# the same host are contiguous in the mesh.
# pylint: disable=g-complex-comprehension
mesh_devices = tuple(
d
for process_index in range(process_count)
for d in xb.local_devices(process_index, backend)
)
# pylint: enable=g-complex-comprehension
elif backend is not None:
mesh_devices = tuple(xb.devices(backend=backend))
else:
mesh_devices = tuple(xb.devices())

if not trace_state_clean and process_count > 1:
# Tracing in multihost: use local devices
return tuple(xb.local_devices(backend=backend)[:local_axis_size])
else:
return mesh_devices[:global_axis_size]
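
The per-process device ordering built in the multi-process branch above can be shown with a self-contained sketch; the `local_devices` dict and its string entries are hypothetical stand-ins for `xb.local_devices(process_index, backend)`:

```python
# Hypothetical stand-in: process_index -> that process's local devices.
local_devices = {0: ["p0_d0", "p0_d1"], 1: ["p1_d0", "p1_d1"]}
process_count = len(local_devices)
local_axis_size = 2                                   # one entry per local device
global_axis_size = local_axis_size * process_count    # 4, as computed above

# Same comprehension as in _get_mesh_devices: devices from the same process
# stay contiguous in the mesh.
mesh_devices = tuple(
    d
    for process_index in range(process_count)
    for d in local_devices[process_index]
)
assert mesh_devices == ("p0_d0", "p0_d1", "p1_d0", "p1_d1")
assert len(mesh_devices[:global_axis_size]) == global_axis_size
```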


@lru_cache
def _local_to_global_aval(local_aval, mesh, pspec):
pspec = sharding_impls.prepare_axis_resources(pspec, 'pspec to array_mapping')