@@ -17,6 +17,7 @@ func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
1717
1818// -----
1919
20+
2021// CHECK-LABEL: func @main
2122func.func @main (%arg0: tensor <8 x8 xf32 >, %arg1: tensor <8 x8 xf32 >) {
2223 // CHECK-DAG: sdy.sharding_group %arg0 group_id=0 : tensor<8x8xf32>
@@ -30,10 +31,11 @@ func.func @main(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) {
3031
3132// -----
3233
34+
3335sdy.mesh @mesh = <[" c" =2 , " a" =2 , " b" =2 ]>
3436
35- // CHECK-LABEL: @add_manual_axes_to_replicated
36- func.func @add_manual_axes_to_replicated (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
37+ // CHECK-LABEL: func @main
38+ func.func @main (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
3739 // CHECK-NEXT: sdy.manual_computation(%arg0)
3840 // CHECK-SAME{LITERAL}: in_shardings=[<@mesh, [{"c", ?}], replicated={"a"}>]
3941 // CHECK-SAME{LITERAL}: out_shardings=[<@mesh, [{"c", ?}], replicated={"a"}>]
@@ -46,14 +48,15 @@ func.func @add_manual_axes_to_replicated(%arg0: tensor<8xf32>) -> tensor<8xf32>
4648
4749// -----
4850
51+
4952sdy.mesh @mesh = <[" c" =2 , " a" =2 , " b" =2 ]>
5053
5154// Due to the in_sharding being fully closed, the in_sharding is added to the
5255// func arg but with the manual axis added as replicated.
53- // CHECK-LABEL: @add_manual_axes_to_replicated_applied_constraint
56+ // CHECK-LABEL: func @main
5457// CHECK-SAME %arg0: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}], replicated={"a"}>}
5558// CHECK-SAME -> tensor<16x16xf32> {
56- func.func @add_manual_axes_to_replicated_applied_constraint (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
59+ func.func @main (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
5760 %0 = sdy.manual_computation (%arg0 ) in_shardings =[<@mesh , [{" c" }]>] out_shardings =[<@mesh , [{" c" }]>] manual_axes ={" c" , " a" } (%arg1: tensor <4 xf32 >) {
5861 %1 = stablehlo.add %arg1 , %arg1 : tensor <4 xf32 >
5962 sdy.return %1 : tensor <4 xf32 >
@@ -63,27 +66,31 @@ func.func @add_manual_axes_to_replicated_applied_constraint(%arg0: tensor<8xf32>
6366
6467// -----
6568
69+
6670sdy.mesh @mesh = <[" a" =2 ]>
6771
6872// This test verifies that the manual axes are cleaned up before adding data
6973// flow edges.
70- func.func @manual_axes_cleanup_before_adding_data_flow_edges (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
74+ func.func @main (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
7175 %0 = sdy.manual_computation (%arg0 ) in_shardings =[<@mesh , [{?}]>] out_shardings =[<@mesh , [{?}]>] manual_axes ={" a" } (%arg1: tensor <8 xf32 >) {
7276 %1 = stablehlo.add %arg1 , %arg1 : tensor <8 xf32 >
7377 sdy.return %1 : tensor <8 xf32 >
7478 } : (tensor <8 xf32 >) -> tensor <8 xf32 >
75- // CHECK: sdy.data_flow_edge %0 sharding=<@mesh, [{? }], replicated={"a"}> : tensor<8xf32>
79+ // CHECK: sdy.data_flow_edge %0 sharding=<@mesh, [{"c" }], replicated={"a"}> : tensor<8xf32>
7680 return %0 : tensor <8 xf32 >
7781}
7882
7983// -----
8084
85+
8186sdy.mesh @mesh = <[" a" =2 ]>
8287
83- // CHECK-LABEL: func @single_call
84- func.func @single_call (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
85- // CHECK-NEXT: %0 = call @foo(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
86- // CHECK-NEXT: return %0 : tensor<8xf32>
88+ // test: single_call
89+ // CHECK-LABEL: func @main
90+ func.func @main (%arg0: tensor <8 xf32 >) -> tensor <8 xf32 > {
91+ // CHECK: %[[CALL:.*]] = call @foo(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
92+ // CHECK: %[[EDGE:.*]] = sdy.func_data_flow_edge %[[CALL]] : tensor<8xf32>
93+ // CHECK: return %[[EDGE]] : tensor<8xf32>
8794 %0 = call @foo (%arg0 ) : (tensor <8 xf32 >) -> tensor <8 xf32 >
8895 return %0 : tensor <8 xf32 >
8996}
0 commit comments