Skip to content

Commit 5a0bd47

Browse files
JustinPan-goog authored and Orbax Authors committed
Fix slice_in_dim device incompatibility with global mesh
PiperOrigin-RevId: 831039670
1 parent 8661033 commit 5a0bd47

File tree

3 files changed

+57
-9
lines changed

3 files changed

+57
-9
lines changed

checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ class SliceArgs:
4646
Intended to be passed to `ReplicaSlice` in order to select a slice of
4747
the `unsliced_data` array.
4848
"""
49+
4950
start_index: int
5051
limit_index: int
5152
axis: int
@@ -81,15 +82,28 @@ def is_on_host(self):
8182
return isinstance(self.unsliced_data, np.ndarray)
8283

8384
def data(self):
85+
"""Returns the sliced data for this replica.
86+
87+
If `slice_args` is None, the entire `unsliced_data` is returned. Otherwise,
88+
a slice of `unsliced_data` is returned based on `slice_args`.
89+
"""
8490
if self.slice_args is None:
8591
return self.unsliced_data
8692
else:
87-
return jax.lax.slice_in_dim(
88-
self.unsliced_data,
89-
start_index=self.slice_args.start_index,
90-
limit_index=self.slice_args.limit_index,
91-
axis=self.slice_args.axis,
93+
# If a global mesh is set, slice_in_dim can fail with incompatible device
94+
# errors. To avoid this, we temporarily set a mesh constructed from
95+
# array's devices.
96+
mesh = jax.sharding.Mesh(
97+
np.array(list(self.unsliced_data.sharding.device_set)), ('data',)
9298
)
99+
with jax.sharding.set_mesh(mesh):
100+
sliced_data = jax.lax.slice_in_dim(
101+
self.unsliced_data,
102+
start_index=self.slice_args.start_index,
103+
limit_index=self.slice_args.limit_index,
104+
axis=self.slice_args.axis,
105+
)
106+
return sliced_data
93107

94108

95109
@dataclasses.dataclass(frozen=True)
@@ -421,10 +435,7 @@ def use_pinned_host_transfer(device: jax.Device):
421435
has_pinned_host = any(
422436
m.kind == 'pinned_host' for m in device.addressable_memories()
423437
)
424-
return (
425-
enable_pinned_host_transfer
426-
and has_pinned_host
427-
)
438+
return enable_pinned_host_transfer and has_pinned_host
428439

429440
def async_transfer_slice(
430441
rslice: ReplicaSlice,

checkpoint/orbax/checkpoint/checkpoint_manager_test.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3907,6 +3907,28 @@ def test_partial_restore_with_omission(self):
39073907
)
39083908
test_utils.assert_tree_equal(self, expected, restored)
39093909

3910+
@parameterized.parameters(True, False)
3911+
def test_save_with_global_mesh(self, use_same_mesh: bool):
3912+
if use_same_mesh:
3913+
devices = np.asarray(jax.devices())
3914+
axis_names = ('x',)
3915+
else:
3916+
if multihost.is_pathways_backend():
3917+
self.skipTest('Not applicable to Pathways.')
3918+
devices = np.asarray(jax.devices()[:4])
3919+
axis_names = ('x',)
3920+
mesh = jax.sharding.Mesh(devices, axis_names)
3921+
jax.sharding.set_mesh(mesh)
3922+
3923+
with CheckpointManager(
3924+
self.directory,
3925+
item_names=('params',),
3926+
) as manager:
3927+
self.assertTrue(self.save_params(0, manager, self.pytree))
3928+
self.wait_if_async(manager)
3929+
restored = self.restore_params(0, manager)
3930+
test_utils.assert_tree_equal(self, self.pytree, restored)
3931+
39103932

39113933
if __name__ == '__main__':
39123934
multiprocess_test.main()

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -926,6 +926,21 @@ def test_partial_restore_omission(self):
926926

927927
test_utils.assert_tree_equal(self, expected, loaded)
928928

929+
@parameterized.parameters(True, False)
930+
def test_save_with_global_mesh(self, use_same_mesh: bool):
931+
if use_same_mesh:
932+
devices = np.asarray(jax.devices())
933+
axis_names = ('x',)
934+
else:
935+
devices = np.asarray(jax.devices()[:4])
936+
axis_names = ('x',)
937+
mesh = jax.sharding.Mesh(devices, axis_names)
938+
jax.sharding.set_mesh(mesh)
939+
940+
ocp.save_pytree(self.directory, self.pytree)
941+
loaded = ocp.load_pytree(self.directory, self.abstract_pytree)
942+
test_utils.assert_tree_equal(self, self.pytree, loaded)
943+
929944
@parameterized.parameters((3,), (8,))
930945
def test_primary_host_background_error(self, timeout):
931946
def _assert_false(*args, **kwargs):

0 commit comments

Comments
 (0)