@@ -88,6 +88,30 @@ func.func @manual_computation(%arg0: tensor<208xf32> {sdy.sharding = #sdy.shardi
8888 return %0 : tensor <208 xf32 >
8989}
9090
91+ // CHECK-LABEL: func @reduce_multiple_results
92+ func.func @reduce_multiple_results (
93+ %arg0: tensor <2 x64 x13 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }, {}, {" y" }]>},
94+ %arg1: tensor <2 x64 x13 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }, {}, {" y" }]>})
95+ -> (tensor <64 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{}], unreduced ={" y" }>},
96+ tensor <64 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{}], unreduced ={" y" }>}) {
97+ %0 = stablehlo.constant dense <0.000000e+00 > : tensor <f32 >
98+ %1 = stablehlo.constant dense <0 > : tensor <i32 >
99+ // CHECK: %[[REDUCE:.*]]:2 = stablehlo.reduce(%arg0 init: %cst), (%arg1 init: %c) across dimensions = [0, 2]
100+ // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"y"}>, <@mesh, [{}], unreduced={"y"}>]>}
101+ // CHECK: %[[ALL_REDUCE0:.*]] = sdy.all_reduce {"x"} %[[REDUCE]]#0 out_sharding=<@mesh, [{}], unreduced={"y"}> : tensor<64xf32>
102+ // CHECK-NEXT: %[[ALL_REDUCE1:.*]] = sdy.all_reduce {"x"} %[[REDUCE]]#1 out_sharding=<@mesh, [{}], unreduced={"y"}> : tensor<64xi32>
103+ // CHECK-NEXT: return %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] : tensor<64xf32>, tensor<64xi32>
104+ %2:2 = stablehlo.reduce (%arg0 init : %0 ), (%arg1 init : %1 ) across dimensions = [0 , 2 ]
105+ {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{}], unreduced ={" y" }>, <@mesh , [{}], unreduced ={" y" }>]>} :
106+ (tensor <2 x64 x13 xf32 >, tensor <2 x64 x13 xi32 >, tensor <f32 >, tensor <i32 >) -> (tensor <64 xf32 >, tensor <64 xi32 >)
107+ reducer (%arg2: tensor <f32 >, %arg4: tensor <f32 >) (%arg3: tensor <i32 >, %arg5: tensor <i32 >) {
108+ %3 = stablehlo.add %arg2 , %arg4 : tensor <f32 >
109+ %4 = stablehlo.add %arg3 , %arg5 : tensor <i32 >
110+ stablehlo.return %3 , %4 : tensor <f32 >, tensor <i32 >
111+ }
112+ return %2#0 , %2#1 : tensor <64 xf32 >, tensor <64 xi32 >
113+ }
114+
91115//===----------------------------------------------------------------------===//
92116// Dot tests
93117//===----------------------------------------------------------------------===//
0 commit comments