Skip to content

Commit 49a123b

Browse files
Confine the logic that finds reduction axes and checks for unreduced axes to the method that inserts all-reduces.
It is a refactoring. PiperOrigin-RevId: 812766518
1 parent b75252b commit 49a123b

3 files changed

Lines changed: 118 additions & 93 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -804,14 +804,60 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
804804
return tensorSizes;
805805
}
806806

807-
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
808-
OpShardingRuleAttr shardingRule) {
809-
SmallVector<AxisRefAttr> reductionAxes;
807+
namespace {
808+
809+
// Returns reduction axes that are the union of all axes on reduction factors.
810+
// The result axes are not necessarily canonicalized.
811+
//
812+
// Returns empty axes if not `onFullVersion` and op results do not have
813+
// unreduced axes.
814+
//
815+
// Assumes `commonAxesPerFactor` is non-empty if `onFullVersion` is true.
816+
//
817+
// Hard fails if some reduction factors do not have compatible shardings.
818+
SmallVector<AxisRefAttr> getReductionAxes(
819+
Operation* op, const ShardingProjection& shardingProjection,
820+
const AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
821+
const bool onFullVersion) {
822+
if (onFullVersion) {
823+
SmallVector<AxisRefAttr> reductionAxes;
824+
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
825+
reductionAxes.append(commonAxesPerFactor[reductionFactor]);
826+
}
827+
return reductionAxes;
828+
}
829+
830+
if (getUnreducedAxes(op->getResult(0)).empty()) {
831+
return {};
832+
}
833+
834+
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
835+
// factors, and simplify the following logic.
836+
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
810837
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
811-
reductionAxes.append(axesPerFactor[reductionFactor]);
838+
// We only iterate operands since reduction factors are not in results.
839+
bool seen = false;
840+
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
841+
for (const TensorFactorShardings& tensorFactorSharding :
842+
shardingProjection.getOperands()) {
843+
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
844+
getFactorSharding(tensorFactorSharding, reductionFactor)) {
845+
if (seen) {
846+
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
847+
<< "For the operation " << op
848+
<< ", the result has unreduced axes while the operand has "
849+
"incompatible sharding along reduction factors.";
850+
} else {
851+
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
852+
seen = true;
853+
}
854+
}
855+
}
856+
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
812857
}
813-
return reductionAxes;
858+
return axesAlongAllReductionFactors;
814859
}
860+
} // namespace
815861

816862
TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
817863
OpOperand& use, TensorShardingAttr sourceSharding,
@@ -869,11 +915,16 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value) {
869915
return getUnreducedAxes(getSharding(value));
870916
}
871917

872-
void insertAllReducesForReductionFactors(Operation* op,
873-
ArrayRef<AxisRefAttr> reductionAxes,
874-
const Mesh& mesh,
875-
IRRewriter& rewriter) {
876-
if (reductionAxes.empty() || op->getResults().empty()) {
918+
void insertAllReducesForReductionFactors(
919+
Operation* op, const ShardingProjection& shardingProjection,
920+
AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
921+
const Mesh& mesh, IRRewriter& rewriter, const bool onFullVersion) {
922+
if (op->getResults().empty()) {
923+
return;
924+
}
925+
SmallVector<AxisRefAttr> reductionAxes = getReductionAxes(
926+
op, shardingProjection, commonAxesPerFactor, shardingRule, onFullVersion);
927+
if (reductionAxes.empty()) {
877928
return;
878929
}
879930

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,6 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value);
8181
// Returns a concatenated array of operand and result tensor sizes.
8282
SmallVector<int64_t> getTensorSizes(Operation* op);
8383

84-
// Returns reduction axes that are the union of all axes on reduction factors.
85-
// The result axes are not necessarily canonicalized.
86-
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
87-
OpShardingRuleAttr shardingRule);
88-
8984
// Returns true iff any tensor factor sharding has non-empty overflow axes.
9085
bool hasOverflowAxes(const ShardingProjection& shardingProjection);
9186

@@ -147,14 +142,22 @@ void insertExplicitReshards(Operation* op,
147142
OpShardingRuleAttr shardingRule,
148143
const SymbolTable& symbolTable, const Mesh& mesh);
149144

150-
// Inserts an `sdy.all-reduce` for each result of `op` if `reductionAxes`
151-
// is non-empty. Assumes the following:
145+
// Inserts an `sdy.all-reduce` for each result of `op`.
146+
//
147+
// Assumes the following:
152148
// - All op results have the same unreduced axes.
153149
// - All op results have the same mesh as `mesh` ignoring device id orders.
154-
void insertAllReducesForReductionFactors(Operation* op,
155-
ArrayRef<AxisRefAttr> reductionAxes,
156-
const Mesh& mesh,
157-
IRRewriter& rewriter);
150+
// - `commonAxesPerFactor` is non-empty if `onFullVersion` is true.
151+
//
152+
// In case `onFullVersion` is false, it inserts all reduces only if op results
153+
// have some unreduced axes.
154+
//
155+
// Hard fails if the reduction factors do not have compatible shardings, and op
156+
// results have unreduced axes.
157+
void insertAllReducesForReductionFactors(
158+
Operation* op, const ShardingProjection& shardingProjection,
159+
AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
160+
const Mesh& mesh, IRRewriter& rewriter, bool onFullVersion);
158161

159162
// Finds common factor axes on the operands and results of `op` so that the
160163
// sharding of `op` is compatible with its sharding rule.

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 43 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,14 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
158158
// return %reshard : tensor<4x8xf32>
159159
// ```
160160
template <class OpTy>
161-
void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
161+
void processDot(OpTy op, ShardingProjection& shardingProjection,
162162
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
163163
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
164164
const Mesh& mesh) {
165165
if (outShardings.empty()) {
166166
// Result doesn't have a sharding.
167167
return;
168168
}
169-
ShardingProjection shardingProjection =
170-
ShardingProjection::build(inShardings, outShardings, shardingRule,
171-
mesh.attr(), /*closedIfMissing=*/true);
172-
173169
const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0);
174170
const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1);
175171
TensorFactorShardings& resultSharding =
@@ -400,31 +396,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
400396
// - All op results have the same unreduced axes.
401397
// - If the op has no results, none of the operands has unreduced axes.
402398
// - Operand and result meshes are the same ignoring device id order.
399+
// - There are no overflow axes.
403400
//
404401
// Returns the common axes per factor, i.e. how each factor is sharded across
405402
// the operands and results.
406-
SmallVector<AxisRefAttr> processOp(Operation* op,
407-
ArrayRef<TensorShardingAttr> inShardings,
408-
ArrayRef<TensorShardingAttr> outShardings,
409-
IRRewriter& rewriter,
410-
const SymbolTable& symbolTable,
411-
OpShardingRuleAttr shardingRule,
412-
const Mesh& mesh, const bool onFullVersion) {
413-
ShardingProjection shardingProjection = ShardingProjection::build(
414-
inShardings, outShardings, shardingRule, mesh.attr(),
415-
/*closedIfMissing=*/true);
416-
417-
// Return without inserting reshards if any factor sharding has overflow
418-
// axes. This case is not handled yet.
419-
// TODO(enver): Handle the case when factor shardings have overflow axes.
420-
if (hasOverflowAxes(shardingProjection)) {
421-
return {};
422-
}
423-
403+
//
404+
// Guarantees to return non-empty `AxesPerFactor` if `onFullVersion` is true.
405+
AxesPerFactor processOp(Operation* op, ShardingProjection& shardingProjection,
406+
ArrayRef<TensorShardingAttr> inShardings,
407+
ArrayRef<TensorShardingAttr> outShardings,
408+
IRRewriter& rewriter, const SymbolTable& symbolTable,
409+
OpShardingRuleAttr shardingRule, const Mesh& mesh,
410+
const bool onFullVersion) {
411+
// Checks if factors are sharded the same way across operands and results.
412+
AxesPerFactor commonAxesPerFactor =
413+
getCompatibleFactorShardings(shardingProjection, shardingRule);
414+
415+
// TODO(b/446833985): Return common axes per factor also when the sharding
416+
// projection has overflow axes.
424417
if (onFullVersion) {
425-
// Checks if factors are sharded the same way across operands and results.
426-
AxesPerFactor commonAxesPerFactor =
427-
getCompatibleFactorShardings(shardingProjection, shardingRule);
428418
// Find compatible shardings if it is not already compatible.
429419
if (commonAxesPerFactor.empty()) {
430420
commonAxesPerFactor =
@@ -443,49 +433,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
443433
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
444434
updateTensorShardings, rewriter, shardingRule,
445435
symbolTable, mesh);
446-
447-
return getReductionAxes(commonAxesPerFactor, shardingRule);
448-
}
449-
450-
TypeSwitch<Operation*>(op)
451-
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452-
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
453-
shardingRule, mesh);
454-
})
455-
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456-
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
457-
symbolTable, shardingRule, mesh);
458-
});
459-
460-
if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
461-
return {};
436+
} else {
437+
TypeSwitch<Operation*>(op)
438+
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
439+
processDot(dotOp, shardingProjection, outShardings, rewriter,
440+
symbolTable, shardingRule, mesh);
441+
})
442+
.Case<stablehlo::DotGeneralOp>(
443+
[&](stablehlo::DotGeneralOp dotGeneralOp) {
444+
processDot(dotGeneralOp, shardingProjection, outShardings,
445+
rewriter, symbolTable, shardingRule, mesh);
446+
});
462447
}
463-
464-
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465-
// factors, and simplify the following logic.
466-
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
467-
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
468-
// We only iterate operands since reduction factors are not in results.
469-
bool seen = false;
470-
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
471-
for (const TensorFactorShardings& tensorFactorSharding :
472-
shardingProjection.getOperands()) {
473-
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
474-
getFactorSharding(tensorFactorSharding, reductionFactor)) {
475-
if (seen) {
476-
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
477-
<< "For the operation " << op
478-
<< ", the result has unreduced axes while the operand has "
479-
"incompatible sharding along reduction factors.";
480-
} else {
481-
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
482-
seen = true;
483-
}
484-
}
485-
}
486-
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
487-
}
488-
return axesAlongAllReductionFactors;
448+
return commonAxesPerFactor;
489449
}
490450

491451
struct InsertExplicitReshardsPass
@@ -544,11 +504,22 @@ struct InsertExplicitReshardsPass
544504
return;
545505
}
546506

547-
SmallVector<AxisRefAttr> reductionAxes =
548-
processOp(op, inShardings, outShardings, rewriter, symbolTable,
549-
shardingRule, *mesh, onFullVersion);
507+
ShardingProjection shardingProjection = ShardingProjection::build(
508+
inShardings, outShardings, shardingRule, mesh->attr(),
509+
/*closedIfMissing=*/true);
510+
// Return without inserting reshards if any factor sharding has overflow
511+
// axes. This case is not handled yet.
512+
// TODO(enver): Handle the case when factor shardings have overflow axes.
513+
if (hasOverflowAxes(shardingProjection)) {
514+
return;
515+
}
516+
AxesPerFactor commonAxesPerFactor =
517+
processOp(op, shardingProjection, inShardings, outShardings, rewriter,
518+
symbolTable, shardingRule, *mesh, onFullVersion);
550519
// TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
551-
insertAllReducesForReductionFactors(op, reductionAxes, *mesh, rewriter);
520+
insertAllReducesForReductionFactors(op, shardingProjection,
521+
commonAxesPerFactor, shardingRule,
522+
*mesh, rewriter, onFullVersion);
552523

553524
// TODO(enver): Remove sharding rules from ops.
554525
});

0 commit comments

Comments
 (0)