Skip to content

Commit 42f9d74

Browse files
ZixuanJiang authored and copybara-github committed
Generate correct reduction axes in the minimal version of explicit reshards.
PiperOrigin-RevId: 811498978
1 parent d4f7057 commit 42f9d74

2 files changed

Lines changed: 31 additions & 10 deletions

File tree

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
395395
// - All op results have the same unreduced axes.
396396
// - If the op has no results, none of the operands has unreduced axes.
397397
//
398-
// Returns the union of common reducation axes which may not be canonicalized.
398+
// Returns the union of common reduction axes which may not be canonicalized.
399399
SmallVector<AxisRefAttr> processOp(Operation* op,
400400
ArrayRef<TensorShardingAttr> inShardings,
401401
ArrayRef<TensorShardingAttr> outShardings,
@@ -427,32 +427,29 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
427427
/*closedIfMissing=*/true);
428428
// TODO(enver): Factor out finding common axes per factor. Share logic with
429429
// getCompatibleFactorShardings.
430-
SmallVector<AxisRefAttr> reductionAxes;
431-
AxesPerFactor commonAxesPerFactor(shardingRule.getNumFactors());
430+
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
432431
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
433432
// We only iterate operands since reduction factors are not in results.
434433
bool seen = false;
435-
SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
434+
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
436435
for (const TensorFactorShardings& tensorFactorSharding :
437436
shardingProjection.getOperands()) {
438437
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
439438
getFactorSharding(tensorFactorSharding, reductionFactor)) {
440-
SmallVector<AxisRefAttr> factorShardingVector =
441-
llvm::to_vector(*factorSharding);
442439
if (seen) {
443-
SDY_CHECK(factorShardingVector == commonAxes)
440+
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
444441
<< "For the operation " << op
445442
<< ", the result has unreduced axes while the operand has "
446443
"incompatible sharding along reduction factors.";
447444
} else {
448-
commonAxes = factorShardingVector;
445+
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
449446
seen = true;
450447
}
451-
reductionAxes.append(commonAxes);
452448
}
453449
}
450+
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
454451
}
455-
return reductionAxes;
452+
return axesAlongAllReductionFactors;
456453
}
457454

458455
struct InsertExplicitReshardsPass

shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,30 @@ func.func @manual_computation(%arg0: tensor<208xf32> {sdy.sharding = #sdy.shardi
8888
return %0 : tensor<208xf32>
8989
}
9090

91+
// CHECK-LABEL: func @reduce_multiple_results
92+
func.func @reduce_multiple_results(
93+
%arg0: tensor<2x64x13xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {"y"}]>},
94+
%arg1: tensor<2x64x13xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {"y"}]>})
95+
-> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}], unreduced={"y"}>},
96+
tensor<64xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}], unreduced={"y"}>}) {
97+
%0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
98+
%1 = stablehlo.constant dense<0> : tensor<i32>
99+
// CHECK: %[[REDUCE:.*]]:2 = stablehlo.reduce(%arg0 init: %cst), (%arg1 init: %c) across dimensions = [0, 2]
100+
// CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"y"}>, <@mesh, [{}], unreduced={"y"}>]>}
101+
// CHECK: %[[ALL_REDUCE0:.*]] = sdy.all_reduce {"x"} %[[REDUCE]]#0 out_sharding=<@mesh, [{}], unreduced={"y"}> : tensor<64xf32>
102+
// CHECK-NEXT: %[[ALL_REDUCE1:.*]] = sdy.all_reduce {"x"} %[[REDUCE]]#1 out_sharding=<@mesh, [{}], unreduced={"y"}> : tensor<64xi32>
103+
// CHECK-NEXT: return %[[ALL_REDUCE0]], %[[ALL_REDUCE1]] : tensor<64xf32>, tensor<64xi32>
104+
%2:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %1) across dimensions = [0, 2]
105+
{sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"y"}>, <@mesh, [{}], unreduced={"y"}>]>} :
106+
(tensor<2x64x13xf32>, tensor<2x64x13xi32>, tensor<f32>, tensor<i32>) -> (tensor<64xf32>, tensor<64xi32>)
107+
reducer(%arg2: tensor<f32>, %arg4: tensor<f32>) (%arg3: tensor<i32>, %arg5: tensor<i32>) {
108+
%3 = stablehlo.add %arg2, %arg4 : tensor<f32>
109+
%4 = stablehlo.add %arg3, %arg5 : tensor<i32>
110+
stablehlo.return %3, %4 : tensor<f32>, tensor<i32>
111+
}
112+
return %2#0, %2#1 : tensor<64xf32>, tensor<64xi32>
113+
}
114+
91115
//===----------------------------------------------------------------------===//
92116
// Dot tests
93117
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)