@@ -943,9 +943,13 @@ def test_cross_host_transfer_cpu_error(self):
943943 ValueError , "does not support cross-host device transfers" ):
944944 jax .device_put (y , dst_sharding )
945945
946+ @parameterized .named_parameters (
947+ ("numpy" , np .arange ),
948+ ("uncommitted" , jnp .arange ),
949+ )
946950 @jtu .skip_on_devices ("cpu" )
947- def test_cross_host_transfer_single_device_sharding (self ):
948- x = np . arange (64 ).reshape (8 , 8 )
951+ def test_cross_host_transfer_single_device_sharding (self , arange_fn ):
952+ x = arange_fn (64 ).reshape (8 , 8 )
949953 src_pid = 0
950954 dst_pid = 1
951955 src_sharding = jax .sharding .SingleDeviceSharding (
@@ -960,9 +964,13 @@ def test_cross_host_transfer_single_device_sharding(self):
960964 else :
961965 self .assertEmpty (z .addressable_shards )
962966
967+ @parameterized .named_parameters (
968+ ("numpy" , np .arange ),
969+ ("uncommitted" , jnp .arange ),
970+ )
963971 @jtu .skip_on_devices ("cpu" )
964- def test_cross_host_transfer_named_sharding (self ):
965- x = np . arange (64 ).reshape (8 , 8 )
972+ def test_cross_host_transfer_named_sharding (self , arange_fn ):
973+ x = arange_fn (64 ).reshape (8 , 8 )
966974 n_local = jax .local_device_count ()
967975 src_pid = 0
968976 dst_pid = 1
@@ -1113,11 +1121,15 @@ def test_device_put_with_mixed_local_and_remote_transfers(self):
11131121 for shard in z .addressable_shards :
11141122 np .testing .assert_array_equal (shard .data , x [shard .index ])
11151123
1124+ @parameterized .named_parameters (
1125+ ("numpy" , np .arange ),
1126+ ("uncommitted" , jnp .arange ),
1127+ )
11161128 @jtu .skip_on_devices ("cpu" )
1117- def test_device_put_to_device (self ):
1129+ def test_device_put_to_device (self , arange_fn ):
11181130 if jaxlib_extension_version < 400 :
11191131 self .skipTest ("This functionality is not yet supported in jaxlib." )
1120- x = np . arange (64 ).reshape (8 , 8 )
1132+ x = arange_fn (64 ).reshape (8 , 8 )
11211133 src_pid = 0
11221134 dst_pid = 1
11231135 src_device = jax .local_devices (process_index = src_pid )[0 ]
0 commit comments