Skip to content

Commit 48477c8

Browse files
Refactor to move logic related to allreduce to the method for it.
PiperOrigin-RevId: 812756921
1 parent b426498 commit 48477c8

3 files changed

Lines changed: 119 additions & 84 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,10 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
804804
return tensorSizes;
805805
}
806806

807+
namespace {
808+
809+
// Returns reduction axes that are the union of all axes on reduction factors.
810+
// The result axes are not necessarily canonicalized.
807811
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
808812
OpShardingRuleAttr shardingRule) {
809813
SmallVector<AxisRefAttr> reductionAxes;
@@ -813,6 +817,43 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
813817
return reductionAxes;
814818
}
815819

820+
// Precondition: the results of `op` have unreduced axes.
821+
//
822+
// Resets `commonAxesPerFactor` to one entry per factor of `shardingRule`, then
823+
// fills in only the entries of the reduction factors with their common axes.
// Entries of all other factors are left as constructed. Any previous content
// of `commonAxesPerFactor` is overwritten.
824+
//
825+
// Hard fails if some reduction factors do not have compatible shardings, i.e.
// if two operands that carry the same reduction factor shard it differently.
826+
void populateCommonAxesPerReductionFactorOrDie(
827+
Operation* op, const ShardingProjection& shardingProjection,
828+
OpShardingRuleAttr shardingRule, AxesPerFactor& commonAxesPerFactor) {
829+
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
830+
// factors, and simplify the following logic.
831+
commonAxesPerFactor = AxesPerFactor(shardingRule.getNumFactors());
832+
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
833+
// We only iterate operands since reduction factors are not in results.
834+
bool seen = false;
835+
// Reference into the output, so assigning to `commonAxes` below writes the
// result for this reduction factor directly.
SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
836+
for (const TensorFactorShardings& tensorFactorSharding :
837+
shardingProjection.getOperands()) {
838+
// Operands for which `getFactorSharding` yields no value are skipped.
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
839+
getFactorSharding(tensorFactorSharding, reductionFactor)) {
840+
SmallVector<AxisRefAttr> factorShardingVector =
841+
llvm::to_vector(*factorSharding);
842+
// The first operand that has this factor defines the common axes; every
// later operand that has it must match exactly.
if (seen) {
843+
// NOTE(review): `op` is an `Operation*`; confirm the check's stream
// prints the operation itself rather than the pointer value.
SDY_CHECK(factorShardingVector == commonAxes)
844+
<< "For the operation " << op
845+
<< ", the result has unreduced axes while the operand has "
846+
"incompatible sharding along reduction factors.";
847+
} else {
848+
commonAxes = factorShardingVector;
849+
seen = true;
850+
}
851+
}
852+
}
853+
}
854+
}
855+
} // namespace
856+
816857
TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
817858
OpOperand& use, TensorShardingAttr sourceSharding,
818859
TensorShardingAttr userSharding, const SymbolTable& symbolTable,
@@ -869,11 +910,26 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value) {
869910
return getUnreducedAxes(getSharding(value));
870911
}
871912

872-
void insertAllReducesForReductionFactors(Operation* op,
873-
ArrayRef<AxisRefAttr> reductionAxes,
874-
const Mesh& mesh,
875-
IRRewriter& rewriter) {
876-
if (reductionAxes.empty() || op->getResults().empty()) {
913+
void insertAllReducesForReductionFactors(
914+
Operation* op, const ShardingProjection& shardingProjection,
915+
AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
916+
const Mesh& mesh, IRRewriter& rewriter, const bool onFullVersion) {
917+
if (op->getResults().empty()) {
918+
return;
919+
}
920+
921+
if (!onFullVersion && getUnreducedAxes(op->getResult(0)).empty()) {
922+
return;
923+
}
924+
if (commonAxesPerFactor.empty()) {
925+
// At this point, there are unreduced axes on results.
926+
populateCommonAxesPerReductionFactorOrDie(
927+
op, shardingProjection, shardingRule, commonAxesPerFactor);
928+
}
929+
930+
SmallVector<AxisRefAttr> reductionAxes =
931+
getReductionAxes(commonAxesPerFactor, shardingRule);
932+
if (reductionAxes.empty()) {
877933
return;
878934
}
879935

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,6 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value);
8181
// Returns a concatenated array of operand and result tensor sizes.
8282
SmallVector<int64_t> getTensorSizes(Operation* op);
8383

84-
// Returns reduction axes that are the union of all axes on reduction factors.
85-
// The result axes are not necessarilly canonicalized.
86-
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
87-
OpShardingRuleAttr shardingRule);
88-
8984
// Returns true iff any tensor factor sharding has non-empty overflow axes.
9085
bool hasOverflowAxes(const ShardingProjection& shardingProjection);
9186

@@ -147,19 +142,28 @@ void insertExplicitReshards(Operation* op,
147142
OpShardingRuleAttr shardingRule,
148143
const SymbolTable& symbolTable, const Mesh& mesh);
149144

150-
// Inserts an `sdy.all-reduce` for each result of `op` if `reductionAxes`
151-
// is non-empty. Assume the followings:
145+
// Inserts an `sdy.all-reduce` for each result of `op`.
146+
//
147+
// Assumes the following:
152148
// - All op results have the same unreduced axes.
153149
// - All op results have the same mesh as `mesh` ignoring device id orders.
154-
void insertAllReducesForReductionFactors(Operation* op,
155-
ArrayRef<AxisRefAttr> reductionAxes,
156-
const Mesh& mesh,
157-
IRRewriter& rewriter);
150+
// - If `commonAxesPerFactor` is nonempty, op has compatible shardings.
151+
//
152+
// In case `onFullVersion` is false, it inserts all reduces only if op results
153+
// have some unreduced axes.
154+
//
155+
// Hard fails if the reduction factors do not have compatible shardings, and op
156+
// results have unreduced axes.
157+
void insertAllReducesForReductionFactors(
158+
Operation* op, const ShardingProjection& shardingProjection,
159+
AxesPerFactor& commonAxesPerFactor, OpShardingRuleAttr shardingRule,
160+
const Mesh& mesh, IRRewriter& rewriter, bool onFullVersion);
158161

159162
// Finds common factor axes on the operands and results of `op` so that the
160163
// sharding of `op` is compatible with its sharding rule.
161164
//
162-
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
165+
// Refer to the documentation of `InsertExplicitReshardsPass` for more
166+
// details.
163167
//
164168
// Assumes the following:
165169
// - All op results have the same unreduced axes.

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 42 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -400,31 +400,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
400400
// - All op results have the same unreduced axes.
401401
// - If the op has no results, none of the operands has unreduced axes.
402402
// - Operand and result meshes are the same ignoring device id order.
403+
// - There are no overflow axes.
403404
//
404405
// Returns the union of axes along all the reduction factors which may not be
405406
// canonicalized.
406-
SmallVector<AxisRefAttr> processOp(Operation* op,
407-
ArrayRef<TensorShardingAttr> inShardings,
408-
ArrayRef<TensorShardingAttr> outShardings,
409-
IRRewriter& rewriter,
410-
const SymbolTable& symbolTable,
411-
OpShardingRuleAttr shardingRule,
412-
const Mesh& mesh, const bool onFullVersion) {
413-
ShardingProjection shardingProjection = ShardingProjection::build(
414-
inShardings, outShardings, shardingRule, mesh.attr(),
415-
/*closedIfMissing=*/true);
416-
417-
// Return without inserting reshards if any factor sharding has overflow
418-
// axes. This case is not handled yet.
419-
// TODO(enver): Handle the case when factor shardings have overflow axes.
420-
if (hasOverflowAxes(shardingProjection)) {
421-
return {};
422-
}
423-
407+
//
408+
// Guaranteed to return a non-empty `AxesPerFactor` if `onFullVersion` is true.
409+
AxesPerFactor processOp(Operation* op, ShardingProjection& shardingProjection,
410+
ArrayRef<TensorShardingAttr> inShardings,
411+
ArrayRef<TensorShardingAttr> outShardings,
412+
IRRewriter& rewriter, const SymbolTable& symbolTable,
413+
OpShardingRuleAttr shardingRule, const Mesh& mesh,
414+
const bool onFullVersion) {
415+
// Checks if factors are sharded the same way across operands and results.
416+
AxesPerFactor commonAxesPerFactor =
417+
getCompatibleFactorShardings(shardingProjection, shardingRule);
418+
419+
// TODO(b/446833985): Return common axes factors also when the sharding
420+
// projection has overflow axes.
424421
if (onFullVersion) {
425-
// Checks if factors are sharded the same way across operands and results.
426-
AxesPerFactor commonAxesPerFactor =
427-
getCompatibleFactorShardings(shardingProjection, shardingRule);
428422
// Find compatible shardings if it is not already compatible.
429423
if (commonAxesPerFactor.empty()) {
430424
commonAxesPerFactor =
@@ -443,49 +437,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
443437
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
444438
updateTensorShardings, rewriter, shardingRule,
445439
symbolTable, mesh);
446-
447-
return getReductionAxes(commonAxesPerFactor, shardingRule);
440+
} else {
441+
TypeSwitch<Operation*>(op)
442+
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
443+
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
444+
shardingRule, mesh);
445+
})
446+
.Case<stablehlo::DotGeneralOp>(
447+
[&](stablehlo::DotGeneralOp dotGeneralOp) {
448+
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
449+
symbolTable, shardingRule, mesh);
450+
});
448451
}
449-
450-
TypeSwitch<Operation*>(op)
451-
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452-
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
453-
shardingRule, mesh);
454-
})
455-
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456-
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
457-
symbolTable, shardingRule, mesh);
458-
});
459-
460-
if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
461-
return {};
462-
}
463-
464-
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465-
// factors, and simplify the following logic.
466-
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
467-
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
468-
// We only iterate operands since reduction factors are not in results.
469-
bool seen = false;
470-
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
471-
for (const TensorFactorShardings& tensorFactorSharding :
472-
shardingProjection.getOperands()) {
473-
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
474-
getFactorSharding(tensorFactorSharding, reductionFactor)) {
475-
if (seen) {
476-
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
477-
<< "For the operation " << op
478-
<< ", the result has unreduced axes while the operand has "
479-
"incompatible sharding along reduction factors.";
480-
} else {
481-
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
482-
seen = true;
483-
}
484-
}
485-
}
486-
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
487-
}
488-
return axesAlongAllReductionFactors;
452+
return commonAxesPerFactor;
489453
}
490454

491455
struct InsertExplicitReshardsPass
@@ -544,11 +508,22 @@ struct InsertExplicitReshardsPass
544508
return;
545509
}
546510

547-
SmallVector<AxisRefAttr> reductionAxes =
548-
processOp(op, inShardings, outShardings, rewriter, symbolTable,
549-
shardingRule, *mesh, onFullVersion);
511+
ShardingProjection shardingProjection = ShardingProjection::build(
512+
inShardings, outShardings, shardingRule, mesh->attr(),
513+
/*closedIfMissing=*/true);
514+
// Return without inserting reshards if any factor sharding has overflow
515+
// axes. This case is not handled yet.
516+
// TODO(enver): Handle the case when factor shardings have overflow axes.
517+
if (hasOverflowAxes(shardingProjection)) {
518+
return;
519+
}
520+
AxesPerFactor commonAxesPerFactor =
521+
processOp(op, shardingProjection, inShardings, outShardings, rewriter,
522+
symbolTable, shardingRule, *mesh, onFullVersion);
550523
// TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
551-
insertAllReducesForReductionFactors(op, reductionAxes, *mesh, rewriter);
524+
insertAllReducesForReductionFactors(op, shardingProjection,
525+
commonAxesPerFactor, shardingRule,
526+
*mesh, rewriter, onFullVersion);
552527

553528
// TODO(enver): Remove sharding rules from ops.
554529
});

0 commit comments

Comments
 (0)