diff --git a/flax/jax_utils.py b/flax/jax_utils.py
index bfe6849f3..ff9c80d31 100644
--- a/flax/jax_utils.py
+++ b/flax/jax_utils.py
@@ -25,6 +25,7 @@
 from jax import core, lax
 from jax.extend import linear_util as lu
 from jax.interpreters import partial_eval as pe
+from jax.sharding import NamedSharding, PartitionSpec as P, AxisType
 
 
 def _pmap_device_order():
@@ -42,7 +43,24 @@ def replicate(tree, devices=None):
     A new pytree containing the replicated arrays.
   """
   devices = devices or _pmap_device_order()
-  return jax.device_put_replicated(tree, devices)
+  mesh = jax.make_mesh(
+      (len(devices),),
+      ("_flax_jax_utils_replicate_data_axis",),
+      (AxisType.Auto,),
+      devices=devices,
+  )
+  data_sharding = NamedSharding(mesh, P("_flax_jax_utils_replicate_data_axis"))
+
+  def _device_put_replicated(x):
+    if isinstance(x, (jax.Array, np.ndarray)):
+      buf = x[None]
+    else:
+      buf = jnp.asarray(x)[None]
+    buf = jnp.concat([buf] * len(devices))
+    return jax.device_put(buf, data_sharding)
+
+  with jax.set_mesh(mesh):
+    return jax.tree.map(_device_put_replicated, tree)
 
 
 def unreplicate(tree):
@@ -137,12 +155,21 @@ def prefetch_to_device(iterator, size, devices=None):
   queue = collections.deque()
   devices = _pmap_device_order() if devices is None else devices
 
+  mesh = jax.make_mesh(
+      (len(devices),),
+      ("_flax_jax_utils_prefetch_to_device_data_axis",),
+      (AxisType.Auto,),
+      devices=devices,
+  )
+  data_sharding = NamedSharding(mesh, P("_flax_jax_utils_prefetch_to_device_data_axis"))
+
   def _prefetch(xs):
-    return jax.device_put_sharded(list(xs), devices)
+    return jax.device_put(xs, data_sharding)
 
   def enqueue(n):  # Enqueues *up to* `n` elements from the iterator.
-    for data in itertools.islice(iterator, n):
-      queue.append(jax.tree_util.tree_map(_prefetch, data))
+    with jax.set_mesh(mesh):
+      for data in itertools.islice(iterator, n):
+        queue.append(jax.tree_util.tree_map(_prefetch, data))
 
   enqueue(size)  # Fill up the buffer.
   while queue:
diff --git a/pyproject.toml b/pyproject.toml
index ebf5dcd90..270a7ea28 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,6 +62,8 @@ testing = [
   "tensorflow_datasets",
   "tensorflow>=2.12.0; python_version<'3.13'",  # to fix Numpy np.bool8 deprecation error
   "tensorflow>=2.20.0; python_version>='3.13'",
+  # Temporary fix for https://github.com/google/flax/issues/5143
+  "keras<3.13",
   "torch",
   "treescope>=0.1.1; python_version>='3.10'",
   "cloudpickle>=3.0.0",
diff --git a/tests/jax_utils_test.py b/tests/jax_utils_test.py
index c2130e700..feb636e1a 100644
--- a/tests/jax_utils_test.py
+++ b/tests/jax_utils_test.py
@@ -106,5 +106,64 @@ def add(params, a, *, b):
     np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))
 
 
+class DataShardingTest(parameterized.TestCase):
+  def setUp(self):
+    if jax.device_count() < 4:
+      self.skipTest('At least 4 devices required')
+
+  @parameterized.product(num_devices=["all", 2])
+  def test_prefetch_to_device(self, num_devices):
+    devices = jax.local_devices()
+    if isinstance(num_devices, int):
+      devices = devices[:num_devices]
+    shape = (len(devices), 4, 16, 16, 3)
+    iterator = (jnp.ones(shape) for _ in range(4))
+
+    data_iter = jax_utils.prefetch_to_device(iterator, size=3, devices=devices)
+    for _ in range(4):
+      data = next(data_iter)
+      self.assertEqual(data.shape, shape)
+      self.assertIsNotNone(data.sharding)
+      sharding_slices_per_device = data.sharding.devices_indices_map(tuple(data.shape))
+      self.assertEqual(len(sharding_slices_per_device), len(devices))
+      # Here we check that sharding_slices_per_device is like
+      # Device(id=2): (slice(2, 3, None), slice(None, None, None), ..., slice(None, None, None))
+      for i, dev in enumerate(devices):
+        sharding_slice = sharding_slices_per_device[dev]
+        self.assertEqual(sharding_slice[0], slice(i, i + 1, None))
+        for sharding_slice_j in sharding_slice[1:]:
+          self.assertEqual(sharding_slice_j, slice(None, None, None))
+
+  @parameterized.product(num_devices=["all", 2])
+  def test_replicate(self, num_devices):
+    devices = jax.local_devices()
+    if isinstance(num_devices, int):
+      devices = devices[:num_devices]
+    num_batches = 5
+    shape = (2, 3)
+    data_tree = [
+        i * jnp.ones((2, 3)) for i in range(num_batches - 2)
+    ] + [4, 5 * np.ones(shape)]
+    out_tree = jax_utils.replicate(data_tree, devices=devices)
+
+    def check_sharding(p):
+      if p.ndim == 1:
+        self.assertEqual(p.shape, (len(devices),))
+      else:
+        self.assertEqual(p.shape, (len(devices), *shape))
+      self.assertIsNotNone(p.sharding)
+      sharding_slices_per_device = p.sharding.devices_indices_map(tuple(p.shape))
+      self.assertEqual(len(sharding_slices_per_device), len(devices))
+      # Here we check that sharding_slices_per_device is like
+      # Device(id=2): (slice(2, 3, None), slice(None, None, None), slice(None, None, None))
+      for i, dev in enumerate(devices):
+        sharding_slice = sharding_slices_per_device[dev]
+        self.assertEqual(sharding_slice[0], slice(i, i + 1, None))
+        for sharding_slice_j in sharding_slice[1:]:
+          self.assertEqual(sharding_slice_j, slice(None, None, None))
+
+    jax.tree.map(check_sharding, out_tree)
+
+
 if __name__ == '__main__':
   absltest.main()
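
Usage sketch (illustrative, not part of the diff): with the changes above, both
helpers return `jax.Array`s carrying a `NamedSharding` over a one-dimensional
mesh instead of pmap-style sharded buffers. The shapes, variable names, and
pytree below are made up for illustration; assuming at least two local devices,
something like the following should hold:

    import jax
    import jax.numpy as jnp
    from flax import jax_utils

    devices = jax.local_devices()

    # replicate(): each leaf gains a leading device axis of size len(devices).
    params = {'w': jnp.ones((2, 3))}
    replicated = jax_utils.replicate(params, devices=devices)
    assert replicated['w'].shape == (len(devices), 2, 3)

    # prefetch_to_device(): batches keep their shape but are placed with a
    # sharding that splits the leading axis across `devices`.
    batches = (jnp.ones((len(devices), 8)) for _ in range(3))
    it = jax_utils.prefetch_to_device(batches, size=2, devices=devices)
    x = next(it)
    assert x.shape == (len(devices), 8)
    assert len(x.sharding.devices_indices_map(x.shape)) == len(devices)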