TL;DR: jax.experimental.multihost_utils.host_local_array_to_global_array fails to shard an array correctly in multi-host GPU jobs and raises a zero-size array error (even though the same script runs fine locally over multiple CPU devices).
Hi all,
I'm running into a weird issue in JAX when sharding arrays for multi-host GPU jobs (the same code runs fine locally). I've put together a minimal reproducible example (MRE) showing the error (see below): whenever I try to shard an array over multiple GPUs I get a strange 'zero-size array' error (the stack trace is at the bottom of this post).
This is an MRE of a larger script that saves some data after running on multiple GPUs. With 1 node and 2 GPUs, the error appears but the data still gets saved. With more GPUs (e.g. 8), or multiple GPUs over multiple nodes (e.g. 2 nodes, 8 GPUs per node), the job crashes outright and gives an even longer stack trace (if you want that, please do let me know!).
I've been testing out the mesh_shape by passing it in with argparse and then re-defining the mesh inside the Python script.
My ideal outcome here is a 2D Mesh where I can distribute very large arrays (on the order of 1e6 to 1e8 rows) over the 'model' axis of the Mesh and then run multiple copies of the sharded array over the 'batch' axis. For example, with a 4x2 mesh and 1e6 rows, the first 500k rows would sit on Mesh[:, 0] and the last 500k on Mesh[:, 1], and both shards would be replicated over the rows of the Mesh (if that makes sense; there's a small sketch of this layout right after the MRE below).
Thanks for the help!
Here's the minimal reproducible example,
import argparse
parser = argparse.ArgumentParser(prog='jax_sharding',
                                 description='test script to test shard_map in JAX')
parser.add_argument('-MS', '--mesh_shape', type=str, default='-1,1',
                    help='Specifies the formatting of the Device Mesh, defaults to mesh shape: (-1, 1)')
args = parser.parse_args()

import math
mesh_shape = tuple(map(int, args.mesh_shape.split(',')))
print(f'Requesting mesh shape of: {mesh_shape}')
n_requested = math.prod(mesh_shape)

import os
# pretend to have n_requested CPU devices when running locally
os.environ["XLA_FLAGS"] = (
    f'--xla_force_host_platform_device_count={n_requested} '
)

import jax
import jaxlib
from jax import numpy as jnp
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.experimental.multihost_utils import host_local_array_to_global_array
from jax.sharding import PartitionSpec as P
from jax.debug import visualize_array_sharding

# only initialise the distributed runtime when running under SLURM
if "SLURM_JOB_ID" in os.environ:
    SLURM_JOB_PARTITION = os.environ['SLURM_JOB_PARTITION']
    SLURM_JOB_ID = int(os.environ['SLURM_JOB_ID'])
    print(f'Launching job (SLURM_JOB_ID: {SLURM_JOB_ID}) on {SLURM_JOB_PARTITION} with {n_requested} devices)')
    jax.distributed.initialize(local_device_ids=list(range(n_requested)))

local_devices = jax.devices()[:n_requested]
devices = mesh_utils.create_device_mesh(mesh_shape=mesh_shape,
                                        devices=local_devices)
devices = devices.reshape(mesh_shape)
mesh = Mesh(devices=devices, axis_names=('batch', 'model'))
print('mesh: ', mesh)

my_array_1d = jnp.arange(640)
my_array_2d = jnp.arange(640 * 24).reshape(640, 24)

print('before sharding')
print('my_array_1d: ', my_array_1d.shape)
print('my_array_2d: ', my_array_2d.shape)

my_array_1d = host_local_array_to_global_array(local_inputs=my_array_1d,
                                               global_mesh=mesh,
                                               pspecs=P('model'))  # split 1d array over 'model' axis
my_array_2d = host_local_array_to_global_array(local_inputs=my_array_2d,
                                               global_mesh=mesh,
                                               pspecs=P('model', None))  # shard rows over 'model' (replicate columns)

print('after sharding')
print('my_array_1d: ', my_array_1d.shape)
visualize_array_sharding(my_array_1d)
print('my_array_2d: ', my_array_2d.shape)
visualize_array_sharding(my_array_2d)
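For reference, here's a minimal single-process sketch of the layout I'm after, using the same XLA flag trick to fake 8 CPU devices (the 4x2 mesh, array shape and names here are just illustrative, not part of the real script): rows are split over 'model' and replicated over 'batch'.

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # 8 fake CPU devices

import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((4, 2))        # 4 'batch' x 2 'model'
mesh = Mesh(devices, axis_names=('batch', 'model'))

rows = 1_000_000
x = jnp.zeros((rows, 24))
# shard the row dimension over 'model' and replicate over 'batch':
# the first 500k rows end up on the Mesh[:, 0] column, the last 500k on Mesh[:, 1]
x = jax.device_put(x, NamedSharding(mesh, P('model', None)))
jax.debug.visualize_array_sharding(x)

In the multi-host version I'd want host_local_array_to_global_array to produce this same global layout from the per-host pieces.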
Are the proxy variables in the environment affecting the set-up here at all?
The stacktrace is here,
WARNING:2025-03-06 10:46:53,274:jax._src.distributed:125: JAX detected proxy variable(s) in the environment as distributed setup: GT_PROXY_MODE X509_USER_PROXY. On some systems, this may cause a hang
of distributed.initialize and you may need to unset these ENV variable(s)
WARNING:2025-03-06 10:46:53,276:jax._src.distributed:125: JAX detected proxy variable(s) in the environment as distributed setup: GT_PROXY_MODE X509_USER_PROXY. On some systems, this may cause a hang
of distributed.initialize and you may need to unset these ENV variable(s)
Requesting mesh shape of: (2, 1)
Launching job (SLURM_JOB_ID: 158502) on volta_compute with 2 devices)
mesh: Mesh('batch': 2, 'model': 1)
before sharding
my_array_1d: (640,)
my_array_2d: (640, 24)
Traceback (most recent call last):
File "/work/home/user/bigstick-quantum-magic-estimator/test_sharding.py", line 50, in <module>
my_array_1d = host_local_array_to_global_array(local_inputs=my_array_1d,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/experimental/multihost_utils.py", line 346, in host_local_array_to_global_array
out_flat = [
^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/experimental/multihost_utils.py", line 347, in <listcomp>
host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/core.py", line 1024, in process_primitive
return primitive.impl(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/experimental/multihost_utils.py", line 242, in host_local_array_to_global_array_impl
local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)
^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/mesh.py", line 380, in local_mesh
return self._local_mesh(xb.process_index())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/mesh.py", line 383, in _local_mesh
return _get_local_mesh(self, process_index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/util.py", line 302, in wrapper
return cached(config.trace_context() if trace_context_in_key else _ignore(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/util.py", line 296, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/jax/_src/mesh.py", line 90, in _get_local_mesh
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/numpy/_core/fromnumeric.py", line 3302, in min
return _wrapreduction(a, np.minimum, 'min', axis, None, out,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/home/user/.jax_0_4_50/lib/python3.11/site-packages/numpy/_core/fromnumeric.py", line 86, in _wrapreduction
return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: zero-size array to reduction operation minimum which has no identity
Requesting mesh shape of: (2, 1)
Launching job (SLURM_JOB_ID: 158502) on volta_compute with 2 devices)
mesh: Mesh('batch': 2, 'model': 1)
before sharding
my_array_1d: (640,)
my_array_2d: (640, 24)
after sharding
my_array_1d: (640,)
┌───────┐
│GPU 0,1│
└───────┘
my_array_2d: (640, 24)
┌───────┐
│ │
│ │
│ │
│ │
│GPU 0,1│
│ │
│ │
│ │
│ │
└───────┘
srun: error: v03: task 1: Exited with exit code 1
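From the traceback, the failure happens inside global_mesh.local_mesh (in _get_local_mesh), which presumably means one of the processes ends up with no devices of its own in the mesh. A couple of hypothetical debug prints I could drop in right after building the mesh, just to show what each process sees:

print(f'process {jax.process_index()}/{jax.process_count()} '
      f'local devices: {jax.local_devices()}')
print(f'process {jax.process_index()} mesh devices: {mesh.devices}')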