@@ -64,7 +64,7 @@ func.func @main(
6464// -----
6565sdy.mesh @mesh = <[" a" =2 , " b" =2 , " c" =2 ]>
6666
67- // Test: named_computation_with_shardings
67+ // Test: call_with_shardings
6868// CHECK-LABEL: func @main
6969func.func @main (%arg0: tensor <8 x2 xi32 >, %arg1: tensor <4 x2 xi32 >) -> tensor <12 x2 xi32 > {
7070 // CHECK-NEXT: %0 = sdy.reshard %arg0 <@mesh, [{"a"}, {}]> : tensor<8x2xi32>
@@ -73,15 +73,13 @@ func.func @main(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> tensor<12x2xi
7373 // CHECK-NEXT: %3 = sdy.reshard %1#1 <@mesh, [{}, {"a"}]> : tensor<4x2xi32>
7474 // CHECK-NEXT: %4 = stablehlo.concatenate %2, %3, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>}
7575 // CHECK-NEXT: return %4 : tensor<12x2xi32>
76- %0:2 = sdy.named_computation <" foo" >(%arg0 , %arg1 ) in_shardings =[<@mesh , [{" a" }, {}]>, <@mesh , [{?}, {}]>] out_shardings =[<@mesh , [{" a" }, {}]>, <@mesh , [{?}, {}]>] (%arg2: tensor <8 x2 xi32 >, %arg3: tensor <4 x2 xi32 >) {
77- sdy.return %arg2 , %arg3 : tensor <8 x2 xi32 >, tensor <4 x2 xi32 >
78- } : (tensor <8 x2 xi32 >, tensor <4 x2 xi32 >) -> (tensor <8 x2 xi32 >, tensor <4 x2 xi32 >)
76+ %0:2 = call @foo (%arg0 , %arg1 ) {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" a" }, {}]>, <@mesh , [{}, {}]>]>} : (tensor <8 x2 xi32 >, tensor <4 x2 xi32 >) -> (tensor <8 x2 xi32 >, tensor <4 x2 xi32 >)
7977 %1 = stablehlo.concatenate %0#0 , %0#1 , dim =0 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{}, {" a" }]>]>} : (tensor <8 x2 xi32 >, tensor <4 x2 xi32 >) -> tensor <12 x2 xi32 >
8078 return %1 : tensor <12 x2 xi32 >
8179}
8280
83- // CHECK-LABEL: func private @foo(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
84- // CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
85- // CHECK-SAME: attributes {sdy.original_func_name = "foo"} {
86- // CHECK-NEXT: return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
87- // CHECK-NEXT: }
81+ func. func private @foo (%arg0: tensor <8 x2 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{" a" }, {}]>}, %arg1: tensor <4 x2 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{}, {}]>})
82+ -> (tensor <8 x2 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{" a" }, {}]>}, tensor <4 x2 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{}, {}]>})
83+ attributes {sdy.original_func_name = " foo" } {
84+ return %arg0 , %arg1 : tensor <8 x2 xi32 >, tensor <4 x2 xi32 >
85+ }
0 commit comments