@@ -79,11 +79,11 @@ bool hasShardedPermutationFactors(
7979// 2. Factors that need replication are unsharded.
8080//
8181// Returns the common axes per factor if the factor sharding is compatible.
82- // Otherwise, returns std::nullopt .
82+ // Otherwise, returns an empty AxesPerFactor.
8383//
8484// Assumes factor shardings do not have overflow axes.
8585// TODO(enver): Handle the case when some factor shardings have overflow axes.
86- std::optional< AxesPerFactor> getCompatibleFactorShardings (
86+ AxesPerFactor getCompatibleFactorShardings (
8787 const ShardingProjection& shardingProjection,
8888 OpShardingRuleAttr shardingRule) {
8989 AxesPerFactor commonAxesPerFactor (shardingRule.getNumFactors ());
@@ -98,15 +98,15 @@ std::optional<AxesPerFactor> getCompatibleFactorShardings(
9898 // and results in order for it to have a compatible sharding.
9999 if (shardingRule.isNeedReplicationFactor (factorIndex)) {
100100 if (!factorSharding.axisRefs .empty ()) {
101- return std:: nullopt ;
101+ return {} ;
102102 }
103103 continue ;
104104 }
105105 if (!seenFactors.test (factorIndex)) {
106106 commonAxesPerFactor[factorIndex] = factorSharding.axisRefs ;
107107 seenFactors.set (factorIndex);
108108 } else if (factorSharding.axisRefs != commonAxesPerFactor[factorIndex]) {
109- return std:: nullopt ;
109+ return {} ;
110110 }
111111 }
112112 }
@@ -590,6 +590,8 @@ AxesPerFactor toAxesPerFactor(const SmallVector<AxisListRef>& factorAxisRefs) {
590590// delete all the pairs from the list that are either with the picked factor,
591591// or with an axis that overlaps with the picked axis. Continue iterating
592592// until the list is empty.
593+ //
594+ // Guaranteed to return a non-empty AxesPerFactor.
593595AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic (
594596 const ShardingProjection& shardingProjection,
595597 OpShardingRuleAttr shardingRule, ArrayRef<int64_t > tensorSizes,
@@ -675,7 +677,7 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
675677 return toAxesPerFactor (factorAxisRefs);
676678}
677679
678- std::optional< int64_t > findTensorIndexToPreferOnUnaryOperation (
680+ int64_t findTensorIndexToPreferOnUnaryOperation (
679681 const ShardingProjection& shardingProjection,
680682 OpShardingRuleAttr shardingRule, ArrayRef<int64_t > tensorSizes,
681683 const Mesh& mesh) {
@@ -708,21 +710,16 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
708710// 1. Either tensor does not have factors that need replication.
709711// 2. Both tensors have the same mesh but may have different device orders.
710712// 3. The factor shardings are not compatible.
713+ //
714+ // Guaranteed to return a non-empty AxesPerFactor.
711715AxesPerFactor findCommonAxesOnUnaryOperation (
712716 ArrayRef<TensorShardingAttr> inShardings,
713717 ArrayRef<TensorShardingAttr> outShardings,
714718 const ShardingProjection& shardingProjection,
715719 OpShardingRuleAttr shardingRule, ArrayRef<int64_t > tensorSizes,
716720 const SymbolTable& symbolTable, const Mesh& mesh) {
717- std::optional<int64_t > tensorIndexToPrefer =
718- findTensorIndexToPreferOnUnaryOperation (shardingProjection, shardingRule,
719- tensorSizes, mesh);
720-
721- // If one tensor can not be chosen to be common axes, return empty so it skips
722- // inserting explicit reshards for the operation.
723- if (tensorIndexToPrefer == std::nullopt ) {
724- return AxesPerFactor ();
725- }
721+ int64_t tensorIndexToPrefer = findTensorIndexToPreferOnUnaryOperation (
722+ shardingProjection, shardingRule, tensorSizes, mesh);
726723
727724 // Set factor shardings to make sure factors that do not appear in the
728725 // preferred tensor are sharded on the other tensor.
@@ -743,7 +740,7 @@ AxesPerFactor findCommonAxesOnUnaryOperation(
743740
744741 // Override with the factor shardings on the preferred tensor.
745742 for (const auto & [factorIndex, factorSharding] :
746- shardingProjection.getTensor (* tensorIndexToPrefer)
743+ shardingProjection.getTensor (tensorIndexToPrefer)
747744 .factorIndexToSharding ) {
748745 factorAxisRefs[factorIndex] = factorSharding.axisRefs ;
749746 }
@@ -797,23 +794,20 @@ void distributeAxisRefsToBatchingFactors(
797794 }
798795}
799796
797+ // Assumes there are no overflow axes.
798+ //
799+ // Guaranteed to return a non-empty AxesPerFactor.
800800AxesPerFactor findCommonAxes (ArrayRef<TensorShardingAttr> inShardings,
801801 ArrayRef<TensorShardingAttr> outShardings,
802802 const ShardingProjection& shardingProjection,
803803 OpShardingRuleAttr shardingRule,
804804 ArrayRef<int64_t > tensorSizes,
805805 const SymbolTable& symbolTable, const Mesh& mesh) {
806- // Return without inserting reshards if any factor sharding has overflow
807- // axes. This case is not handled yet.
808- // TODO(enver): Handle the case when factor shardings have overflow axes.
809- if (hasOverflowAxes (shardingProjection)) {
810- return AxesPerFactor ();
811- }
812-
813806 // Checks if factors are sharded the same way across operands and results.
814- if (std::optional<AxesPerFactor> commonAxesPerFactor =
815- getCompatibleFactorShardings (shardingProjection, shardingRule)) {
816- return std::move (commonAxesPerFactor.value ());
807+ if (AxesPerFactor commonAxesPerFactor =
808+ getCompatibleFactorShardings (shardingProjection, shardingRule);
809+ !commonAxesPerFactor.empty ()) {
810+ return commonAxesPerFactor;
817811 }
818812
819813 // Handle the special case of unary operations without factors that need
@@ -969,14 +963,18 @@ SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
969963
970964 UpdateTensorShardings updateTensorShardings (shardingRule.getNumOperands (),
971965 shardingRule.getNumResults ());
966+
967+ // Return without inserting reshards if any factor sharding has overflow
968+ // axes. This case is not handled yet.
969+ // TODO(b/446833985): Handle the case when factor shardings have overflow
970+ // axes.
971+ if (hasOverflowAxes (shardingProjection)) {
972+ return {};
973+ }
974+
972975 AxesPerFactor commonAxesPerFactor =
973976 findCommonAxes (inShardings, outShardings, shardingProjection,
974977 shardingRule, getTensorSizes (op), symbolTable, mesh);
975- // TODO(b/446833985): Return common axes factors also when the sharding
976- // projection have overflow axes.
977- if (commonAxesPerFactor.empty ()) {
978- return {};
979- }
980978 for (const auto & [index, axes] : llvm::enumerate (commonAxesPerFactor)) {
981979 // TODO(enver): Add unit tests to test overflow axes are cleared after
982980 // handling the case that some factors have overflow axes.
0 commit comments