Skip to content

Commit 131a25f

Browse files
Refactor so that findCommonAxes always returns a non-empty AxesPerFactor.
PiperOrigin-RevId: 811849230
1 parent c8ff033 commit 131a25f

1 file changed

Lines changed: 28 additions & 30 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 28 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -79,11 +79,11 @@ bool hasShardedPermutationFactors(
7979
// 2. Factors that need replication are unsharded.
8080
//
8181
// Returns the common axes per factor if the factor sharding is compatible.
82-
// Otherwise, returns std::nullopt.
82+
// Otherwise, returns empty AxesPerFactor.
8383
//
8484
// Assumes factor shardings do not have overflow axes.
8585
// TODO(enver): Handle the case when some factor shardings have overflow axes.
86-
std::optional<AxesPerFactor> getCompatibleFactorShardings(
86+
AxesPerFactor getCompatibleFactorShardings(
8787
const ShardingProjection& shardingProjection,
8888
OpShardingRuleAttr shardingRule) {
8989
AxesPerFactor commonAxesPerFactor(shardingRule.getNumFactors());
@@ -98,15 +98,15 @@ std::optional<AxesPerFactor> getCompatibleFactorShardings(
9898
// and results in order for it to have a compatible sharding.
9999
if (shardingRule.isNeedReplicationFactor(factorIndex)) {
100100
if (!factorSharding.axisRefs.empty()) {
101-
return std::nullopt;
101+
return {};
102102
}
103103
continue;
104104
}
105105
if (!seenFactors.test(factorIndex)) {
106106
commonAxesPerFactor[factorIndex] = factorSharding.axisRefs;
107107
seenFactors.set(factorIndex);
108108
} else if (factorSharding.axisRefs != commonAxesPerFactor[factorIndex]) {
109-
return std::nullopt;
109+
return {};
110110
}
111111
}
112112
}
@@ -590,6 +590,8 @@ AxesPerFactor toAxesPerFactor(const SmallVector<AxisListRef>& factorAxisRefs) {
590590
// delete all the pairs from the list that are either with the picked factor,
591591
// or with an axis that overlaps with the picked axis. Continue iterating
592592
// until the list is empty.
593+
//
594+
// Guaranteed to return a non-empty AxesPerFactor.
593595
AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
594596
const ShardingProjection& shardingProjection,
595597
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
@@ -675,7 +677,7 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
675677
return toAxesPerFactor(factorAxisRefs);
676678
}
677679

678-
std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
680+
int64_t findTensorIndexToPreferOnUnaryOperation(
679681
const ShardingProjection& shardingProjection,
680682
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
681683
const Mesh& mesh) {
@@ -708,21 +710,16 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
708710
// 1. Either tensor does not have factors that need replication.
709711
// 2. Both tensors have the same mesh but may have different device orders.
710712
// 3. The factor shardings are not compatible.
713+
//
714+
// Guaranteed to return a non-empty AxesPerFactor.
711715
AxesPerFactor findCommonAxesOnUnaryOperation(
712716
ArrayRef<TensorShardingAttr> inShardings,
713717
ArrayRef<TensorShardingAttr> outShardings,
714718
const ShardingProjection& shardingProjection,
715719
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
716720
const SymbolTable& symbolTable, const Mesh& mesh) {
717-
std::optional<int64_t> tensorIndexToPrefer =
718-
findTensorIndexToPreferOnUnaryOperation(shardingProjection, shardingRule,
719-
tensorSizes, mesh);
720-
721-
// If one tensor cannot be chosen to be common axes, return empty so it skips
722-
// inserting explicit reshards for the operation.
723-
if (tensorIndexToPrefer == std::nullopt) {
724-
return AxesPerFactor();
725-
}
721+
int64_t tensorIndexToPrefer = findTensorIndexToPreferOnUnaryOperation(
722+
shardingProjection, shardingRule, tensorSizes, mesh);
726723

727724
// Set factor shardings to make sure factors that do not appear in the
728725
// preferred tensor are sharded on the other tensor.
@@ -743,7 +740,7 @@ AxesPerFactor findCommonAxesOnUnaryOperation(
743740

744741
// Override with the factor shardings on the preferred tensor.
745742
for (const auto& [factorIndex, factorSharding] :
746-
shardingProjection.getTensor(*tensorIndexToPrefer)
743+
shardingProjection.getTensor(tensorIndexToPrefer)
747744
.factorIndexToSharding) {
748745
factorAxisRefs[factorIndex] = factorSharding.axisRefs;
749746
}
@@ -797,23 +794,20 @@ void distributeAxisRefsToBatchingFactors(
797794
}
798795
}
799796

797+
// Assumes there are no overflow axes.
798+
//
799+
// Guaranteed to return a non-empty AxesPerFactor.
800800
AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,
801801
ArrayRef<TensorShardingAttr> outShardings,
802802
const ShardingProjection& shardingProjection,
803803
OpShardingRuleAttr shardingRule,
804804
ArrayRef<int64_t> tensorSizes,
805805
const SymbolTable& symbolTable, const Mesh& mesh) {
806-
// Return without inserting reshards if any factor sharding has overflow
807-
// axes. This case is not handled yet.
808-
// TODO(enver): Handle the case when factor shardings have overflow axes.
809-
if (hasOverflowAxes(shardingProjection)) {
810-
return AxesPerFactor();
811-
}
812-
813806
// Checks if factors are sharded the same way across operands and results.
814-
if (std::optional<AxesPerFactor> commonAxesPerFactor =
815-
getCompatibleFactorShardings(shardingProjection, shardingRule)) {
816-
return std::move(commonAxesPerFactor.value());
807+
if (AxesPerFactor commonAxesPerFactor =
808+
getCompatibleFactorShardings(shardingProjection, shardingRule);
809+
!commonAxesPerFactor.empty()) {
810+
return commonAxesPerFactor;
817811
}
818812

819813
// Handle the special case of unary operations without factors that need
@@ -969,14 +963,18 @@ SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
969963

970964
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
971965
shardingRule.getNumResults());
966+
967+
// Return without inserting reshards if any factor sharding has overflow
968+
// axes. This case is not handled yet.
969+
// TODO(b/446833985): Handle the case when factor shardings have overflow
970+
// axes.
971+
if (hasOverflowAxes(shardingProjection)) {
972+
return {};
973+
}
974+
972975
AxesPerFactor commonAxesPerFactor =
973976
findCommonAxes(inShardings, outShardings, shardingProjection,
974977
shardingRule, getTensorSizes(op), symbolTable, mesh);
975-
// TODO(b/446833985): Return common axes factors also when the sharding
976-
// projection has overflow axes.
977-
if (commonAxesPerFactor.empty()) {
978-
return {};
979-
}
980978
for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
981979
// TODO(enver): Add unit tests to test overflow axes are cleared after
982980
// handling the case that some factors have overflow axes.

0 commit comments

Comments
 (0)