@@ -45,9 +45,6 @@ limitations under the License.
4545namespace mlir {
4646namespace sdy {
4747
48- namespace {
49-
50- // Returns true iff any tensor factor sharding has non-empty overflow axes.
5148bool hasOverflowAxes (const ShardingProjection& shardingProjection) {
5249 for (const TensorFactorShardings& tensorFactorSharding :
5350 llvm::concat<const TensorFactorShardings>(
@@ -62,6 +59,7 @@ bool hasOverflowAxes(const ShardingProjection& shardingProjection) {
6259 return false ;
6360}
6461
62+ namespace {
6563bool hasShardedPermutationFactors (
6664 const TensorFactorShardings& tensorFactorSharding,
6765 OpShardingRuleAttr shardingRule) {
@@ -157,44 +155,8 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
157155 sharding.getMesh (symbolTable).getDeviceIds () !=
158156 mesh.attr ().getDeviceIds ();
159157}
158+ } // namespace
160159
161- // Insert explicit reshards for operands and results that change by
162- // the given `shardingProjection` for a given `op`. The reshards are inserted
163- // only to make the given operation compatible.
164- //
165- // For example,
166- //
167- // ```mlir
168- // %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
169- // %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
170- // %0 = stablehlo.dot %arg0, %arg1 { sdy.sharding = <@mesh, [{"x"}, {}]>,
171- // sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
172- // %1 = stablehlo.negate %0 {sdy.sharding = <@mesh, [{"x"}, {}]>
173- // return %1
174- // ```
175- //
176- // after a call on the stablehlo.dot operation, by the sharding projection,
177- // i: {}, j: {}, k: {"y"}, the module becomes:
178- //
179- // ```mlir
180- // %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
181- // %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
182- // %0 = stablehlo.reshard %arg1 {sdy.sharding = <@mesh, [{"y"}, {}]>}
183- // %1 = stablehlo.dot %arg0, %0 { sdy.sharding = <@mesh, [{}, {}]>,
184- // sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
185- // %2 = stablehlo.reshard %1 {sdy.sharding = <@mesh, [{"x"}, {}]>}
186- // %3 = stablehlo.negate %2 {sdy.sharding = <@mesh, [{"x"}, {}]>
187- // return %3
188- // ```
189- //
190- // In the above example, note that the operand and result shardings for
191- // stablehlo.negate op remained unchanged.
192- //
193- // Assumes factor shardings do not have overflow axes.
194- // TODO(enver): Handle the case when some factor shardings have overflow axes.
195- //
196- // Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
197- // different on device order.
198160void insertExplicitReshards (Operation* op,
199161 ArrayRef<TensorShardingAttr> inShardings,
200162 ArrayRef<TensorShardingAttr> outShardings,
@@ -223,6 +185,7 @@ void insertExplicitReshards(Operation* op,
223185 }
224186}
225187
188+ namespace {
226189struct FactorAxesPair {
227190 constexpr static int64_t kEmptyFactorIndex = -1 ;
228191 constexpr static int64_t kTombstoneFactorIndex = -2 ;
@@ -793,6 +756,7 @@ void distributeAxisRefsToBatchingFactors(
793756 }
794757 }
795758}
759+ } // namespace
796760
797761// Assumes there are no overflow axes.
798762//
@@ -855,8 +819,6 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
855819 return tensorSizes;
856820}
857821
858- // Returns reduction axes that are the union of all axes on reduction factors.
859- // The result axes are not necessarilly canonicalized.
860822SmallVector<AxisRefAttr> getReductionAxes (const AxesPerFactor& axesPerFactor,
861823 OpShardingRuleAttr shardingRule) {
862824 SmallVector<AxisRefAttr> reductionAxes;
@@ -865,7 +827,6 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
865827 }
866828 return reductionAxes;
867829}
868- } // namespace
869830
870831TensorShardingAttr insertAllReduceIfUnreducedToReplicated (
871832 OpOperand& use, TensorShardingAttr sourceSharding,
@@ -952,41 +913,5 @@ void insertAllReducesForReductionFactors(Operation* op,
952913 }
953914}
954915
955- SmallVector<AxisRefAttr> insertExplicitReshardsOnOp (
956- Operation* op, ArrayRef<TensorShardingAttr> inShardings,
957- ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
958- const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
959- const Mesh& mesh) {
960- ShardingProjection shardingProjection = ShardingProjection::build (
961- inShardings, outShardings, shardingRule, mesh.attr (),
962- /* closedIfMissing=*/ true );
963-
964- UpdateTensorShardings updateTensorShardings (shardingRule.getNumOperands (),
965- shardingRule.getNumResults ());
966-
967- // Return without inserting reshards if any factor sharding has overflow
968- // axes. This case is not handled yet.
969- // TODO(b/446833985): Handle the case when factor shardings have overflow
970- // axes.
971- if (hasOverflowAxes (shardingProjection)) {
972- return {};
973- }
974-
975- AxesPerFactor commonAxesPerFactor =
976- findCommonAxes (inShardings, outShardings, shardingProjection,
977- shardingRule, getTensorSizes (op), symbolTable, mesh);
978- for (const auto & [index, axes] : llvm::enumerate (commonAxesPerFactor)) {
979- // TODO(enver): Add unit tests to test overflow axes are cleared after
980- // handling the case that some factors have overflow axes.
981- updateTensorShardings |=
982- shardingProjection.updateSharding (index, axes, /* overflowAxes=*/ {});
983- }
984- insertExplicitReshards (op, inShardings, outShardings, shardingProjection,
985- updateTensorShardings, rewriter, shardingRule,
986- symbolTable, mesh);
987-
988- return getReductionAxes (commonAxesPerFactor, shardingRule);
989- }
990-
991916} // namespace sdy
992917} // namespace mlir
0 commit comments