@@ -334,6 +334,109 @@ def test_transfer(self, partitioned, use_replica_parallel):
334334 # With single-replica we transfer a single slice for each shard.
335335 self .assertLen (rslices .replica_slices , num_partitions )
336336
337+ def test_nbytes_no_slicing (self ):
338+ rslice = replica_slices .ReplicaSlice (
339+ index = (slice (None ),),
340+ unsliced_data = np .zeros ((2 , 3 ), dtype = np .float32 ),
341+ slice_args = None ,
342+ )
343+ rslices = replica_slices .ReplicaSlices (
344+ global_shape = (2 , 3 ),
345+ local_shape = (2 , 3 ),
346+ sharding = jax .sharding .NamedSharding (
347+ jax .sharding .Mesh (np .array (jax .devices ()), ('x' ,)),
348+ jax .sharding .PartitionSpec (),
349+ ),
350+ dtype = np .dtype (np .float32 ),
351+ is_on_host = True ,
352+ replica_slices = [rslice ],
353+ )
354+ expected_nbytes = 2 * 3 * np .dtype (np .float32 ).itemsize
355+ self .assertEqual (rslices .nbytes , expected_nbytes )
356+
357+ def test_nbytes_with_slicing (self ):
358+ rslice = replica_slices .ReplicaSlice (
359+ index = (slice (None ),),
360+ unsliced_data = jax .numpy .zeros ((4 , 5 ), dtype = np .int16 ),
361+ slice_args = replica_slices .SliceArgs (
362+ start_index = 1 , limit_index = 3 , axis = 0
363+ ),
364+ )
365+ rslices = replica_slices .ReplicaSlices (
366+ global_shape = (4 , 5 ),
367+ local_shape = (4 , 5 ),
368+ sharding = jax .sharding .NamedSharding (
369+ jax .sharding .Mesh (np .array (jax .devices ()), ('x' ,)),
370+ jax .sharding .PartitionSpec (),
371+ ),
372+ dtype = np .dtype (np .int16 ),
373+ is_on_host = False , # Set to False to allow slice_args
374+ replica_slices = [rslice ],
375+ )
376+ # Shape of the slice is (2, 5)
377+ expected_nbytes = 2 * 5 * np .dtype (np .int16 ).itemsize
378+ self .assertEqual (rslices .nbytes , expected_nbytes )
379+
380+ def test_nbytes_multiple_slices (self ):
381+ rslice1 = replica_slices .ReplicaSlice (
382+ index = (slice (None ),),
383+ unsliced_data = jax .numpy .zeros ((4 , 5 ), dtype = np .int16 ),
384+ slice_args = replica_slices .SliceArgs (
385+ start_index = 0 , limit_index = 2 , axis = 0
386+ ),
387+ )
388+ rslice2 = replica_slices .ReplicaSlice (
389+ index = (slice (None ),),
390+ unsliced_data = jax .numpy .zeros ((4 , 5 ), dtype = np .int16 ),
391+ slice_args = replica_slices .SliceArgs (
392+ start_index = 2 , limit_index = 4 , axis = 0
393+ ),
394+ )
395+ rslices = replica_slices .ReplicaSlices (
396+ global_shape = (4 , 5 ),
397+ local_shape = (4 , 5 ),
398+ sharding = jax .sharding .NamedSharding (
399+ jax .sharding .Mesh (np .array (jax .devices ()), ('x' ,)),
400+ jax .sharding .PartitionSpec (),
401+ ),
402+ dtype = np .dtype (np .int16 ),
403+ is_on_host = False , # Set to False to allow slice_args
404+ replica_slices = [rslice1 , rslice2 ],
405+ )
406+ # Shape of each slice is (2, 5)
407+ expected_nbytes = (2 * 5 * np .dtype (np .int16 ).itemsize ) * 2
408+ self .assertEqual (rslices .nbytes , expected_nbytes )
409+
410+ def test_nbytes_mixed_slicing (self ):
411+ rslice1 = replica_slices .ReplicaSlice (
412+ index = (slice (None ),),
413+ unsliced_data = jax .numpy .zeros ((4 , 5 ), dtype = np .int16 ), # jax.Array
414+ slice_args = None ,
415+ )
416+ rslice2 = replica_slices .ReplicaSlice (
417+ index = (slice (None ),),
418+ unsliced_data = jax .numpy .zeros ((4 , 5 ), dtype = np .int16 ), # jax.Array
419+ slice_args = replica_slices .SliceArgs (
420+ start_index = 0 , limit_index = 2 , axis = 1
421+ ),
422+ )
423+ rslices = replica_slices .ReplicaSlices (
424+ global_shape = (4 , 5 ),
425+ local_shape = (4 , 5 ),
426+ sharding = jax .sharding .NamedSharding (
427+ jax .sharding .Mesh (np .array (jax .devices ()), ('x' ,)),
428+ jax .sharding .PartitionSpec (),
429+ ),
430+ dtype = np .dtype (np .int16 ),
431+ is_on_host = False , # Set to False to allow slice_args
432+ replica_slices = [rslice1 , rslice2 ],
433+ )
434+ # bytes for rslice1: (4 * 5)
435+ # bytes for rslice2: (4 * 2)
436+ expected_nbytes = (4 * 5 * np .dtype (np .int16 ).itemsize ) + (
437+ 4 * 2 * np .dtype (np .int16 ).itemsize
438+ )
439+ self .assertEqual (rslices .nbytes , expected_nbytes )
337440
338441if __name__ == '__main__' :
339442 absltest .main ()
0 commit comments