@@ -699,8 +699,10 @@ def foo(x):
699699 def test_disallow_alias_copies_arrays (self ):
700700 if xla_extension_version < 296 :
701701 self .skipTest ("Requires xla_extension_version >= 296" )
702- _ , _ , _ , inp_host = _create_inputs (
703- (8 , 2 ), P ("x" , "y" ), mem_kind = "pinned_host" )
702+ mesh = jtu .create_mesh ((2 ,), ("x" ,))
703+ np_inp = np .arange (16 ).reshape (8 , 2 )
704+ s = NamedSharding (mesh , P ("x" ), memory_kind = "pinned_host" )
705+ inp_host = jax .device_put (np_inp , s )
704706
705707 inp_host_copy = jax .device_put (inp_host , may_alias = False )
706708
@@ -712,8 +714,10 @@ def test_disallow_alias_copies_arrays(self):
712714 def test_disallow_alias_copies_arrays_with_donated_input (self ):
713715 if xla_extension_version < 296 :
714716 self .skipTest ("Requires xla_extension_version >= 296" )
715- _ , _ , _ , inp_host = _create_inputs (
716- (8 , 2 ), P ("x" , "y" ), mem_kind = "pinned_host" )
717+ mesh = jtu .create_mesh ((2 ,), ("x" ,))
718+ np_inp = np .arange (16 ).reshape (8 , 2 )
719+ s = NamedSharding (mesh , P ("x" ), memory_kind = "pinned_host" )
720+ inp_host = jax .device_put (np_inp , s )
717721
718722 inp_host_donate = jax .jit (lambda x : x , donate_argnums = 0 )(inp_host )
719723
0 commit comments