Skip to content

Commit f20d857

Browse files
author
Orbax Authors
committed
Fix nbytes calculation for ReplicaSlices to consider slices.
PiperOrigin-RevId: 794485404
1 parent b7f7247 commit f20d857

File tree

2 files changed

+119
-2
lines changed

2 files changed

+119
-2
lines changed

checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,22 @@ def __post_init__(self):
115115

116116
@property
117117
def nbytes(self) -> int:
118-
slice_nbytes = math.prod(self.local_shape) * self.dtype.itemsize
119-
return slice_nbytes * len(self.replica_slices)
118+
"""Returns the total number of bytes for all replica slices."""
119+
total_bytes = 0
120+
for rslice in self.replica_slices:
121+
if rslice.slice_args is None:
122+
# No slicing, use the full local_shape
123+
slice_nbytes = math.prod(self.local_shape) * self.dtype.itemsize
124+
else:
125+
# Replica-parallel, calculate bytes for the specific slice
126+
slice_shape = list(self.local_shape)
127+
axis = rslice.slice_args.axis
128+
slice_shape[axis] = (
129+
rslice.slice_args.limit_index - rslice.slice_args.start_index
130+
)
131+
slice_nbytes = math.prod(slice_shape) * self.dtype.itemsize
132+
total_bytes += slice_nbytes
133+
return total_bytes
120134

121135
def to_fragments(self) -> fragments.Fragments:
122136
"""Converts replica slices to fragments."""

checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,109 @@ def test_transfer(self, partitioned, use_replica_parallel):
334334
# With single-replica we transfer a single slice for each shard.
335335
self.assertLen(rslices.replica_slices, num_partitions)
336336

337+
def test_nbytes_no_slicing(self):
338+
rslice = replica_slices.ReplicaSlice(
339+
index=(slice(None),),
340+
unsliced_data=np.zeros((2, 3), dtype=np.float32),
341+
slice_args=None,
342+
)
343+
rslices = replica_slices.ReplicaSlices(
344+
global_shape=(2, 3),
345+
local_shape=(2, 3),
346+
sharding=jax.sharding.NamedSharding(
347+
jax.sharding.Mesh(np.array(jax.devices()), ('x',)),
348+
jax.sharding.PartitionSpec(),
349+
),
350+
dtype=np.dtype(np.float32),
351+
is_on_host=True,
352+
replica_slices=[rslice],
353+
)
354+
expected_nbytes = 2 * 3 * np.dtype(np.float32).itemsize
355+
self.assertEqual(rslices.nbytes, expected_nbytes)
356+
357+
def test_nbytes_with_slicing(self):
358+
rslice = replica_slices.ReplicaSlice(
359+
index=(slice(None),),
360+
unsliced_data=jax.numpy.zeros((4, 5), dtype=np.int16),
361+
slice_args=replica_slices.SliceArgs(
362+
start_index=1, limit_index=3, axis=0
363+
),
364+
)
365+
rslices = replica_slices.ReplicaSlices(
366+
global_shape=(4, 5),
367+
local_shape=(4, 5),
368+
sharding=jax.sharding.NamedSharding(
369+
jax.sharding.Mesh(np.array(jax.devices()), ('x',)),
370+
jax.sharding.PartitionSpec(),
371+
),
372+
dtype=np.dtype(np.int16),
373+
is_on_host=False, # Set to False to allow slice_args
374+
replica_slices=[rslice],
375+
)
376+
# Shape of the slice is (2, 5)
377+
expected_nbytes = 2 * 5 * np.dtype(np.int16).itemsize
378+
self.assertEqual(rslices.nbytes, expected_nbytes)
379+
380+
def test_nbytes_multiple_slices(self):
381+
rslice1 = replica_slices.ReplicaSlice(
382+
index=(slice(None),),
383+
unsliced_data=jax.numpy.zeros((4, 5), dtype=np.int16),
384+
slice_args=replica_slices.SliceArgs(
385+
start_index=0, limit_index=2, axis=0
386+
),
387+
)
388+
rslice2 = replica_slices.ReplicaSlice(
389+
index=(slice(None),),
390+
unsliced_data=jax.numpy.zeros((4, 5), dtype=np.int16),
391+
slice_args=replica_slices.SliceArgs(
392+
start_index=2, limit_index=4, axis=0
393+
),
394+
)
395+
rslices = replica_slices.ReplicaSlices(
396+
global_shape=(4, 5),
397+
local_shape=(4, 5),
398+
sharding=jax.sharding.NamedSharding(
399+
jax.sharding.Mesh(np.array(jax.devices()), ('x',)),
400+
jax.sharding.PartitionSpec(),
401+
),
402+
dtype=np.dtype(np.int16),
403+
is_on_host=False, # Set to False to allow slice_args
404+
replica_slices=[rslice1, rslice2],
405+
)
406+
# Shape of each slice is (2, 5)
407+
expected_nbytes = (2 * 5 * np.dtype(np.int16).itemsize) * 2
408+
self.assertEqual(rslices.nbytes, expected_nbytes)
409+
410+
def test_nbytes_mixed_slicing(self):
411+
rslice1 = replica_slices.ReplicaSlice(
412+
index=(slice(None),),
413+
unsliced_data=jax.numpy.zeros((4, 5), dtype=np.int16), # jax.Array
414+
slice_args=None,
415+
)
416+
rslice2 = replica_slices.ReplicaSlice(
417+
index=(slice(None),),
418+
unsliced_data=jax.numpy.zeros((4, 5), dtype=np.int16), # jax.Array
419+
slice_args=replica_slices.SliceArgs(
420+
start_index=0, limit_index=2, axis=1
421+
),
422+
)
423+
rslices = replica_slices.ReplicaSlices(
424+
global_shape=(4, 5),
425+
local_shape=(4, 5),
426+
sharding=jax.sharding.NamedSharding(
427+
jax.sharding.Mesh(np.array(jax.devices()), ('x',)),
428+
jax.sharding.PartitionSpec(),
429+
),
430+
dtype=np.dtype(np.int16),
431+
is_on_host=False, # Set to False to allow slice_args
432+
replica_slices=[rslice1, rslice2],
433+
)
434+
# bytes for rslice1: (4 * 5)
435+
# bytes for rslice2: (4 * 2)
436+
expected_nbytes = (4 * 5 * np.dtype(np.int16).itemsize) + (
437+
4 * 2 * np.dtype(np.int16).itemsize
438+
)
439+
self.assertEqual(rslices.nbytes, expected_nbytes)
337440

338441
if __name__ == '__main__':
339442
absltest.main()

0 commit comments

Comments
 (0)