@@ -158,71 +158,6 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
158158 mesh.attr ().getDeviceIds ();
159159}
160160
161- // Insert explicit reshards for operands and results that change by
162- // the given `shardingProjection` for a given `op`. The reshards are inserted
163- // only to make the given operation compatible.
164- //
165- // For example,
166- //
167- // ```mlir
168- // %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
169- // %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
170- // %0 = stablehlo.dot %arg0, %arg1 { sdy.sharding = <@mesh, [{"x"}, {}]>,
171- // sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
172- // %1 = stablehlo.negate %0 {sdy.sharding = <@mesh, [{"x"}, {}]>
173- // return %1
174- // ```
175- //
176- // after a call on the stablehlo.dot operation, by the sharding projection,
177- // i: {}, j: {}, k: {"y"}, the module becomes:
178- //
179- // ```mlir
180- // %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
181- // %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
182- // %0 = stablehlo.reshard %arg1 {sdy.sharding = <@mesh, [{"y"}, {}]>}
183- // %1 = stablehlo.dot %arg0, %0 { sdy.sharding = <@mesh, [{}, {}]>,
184- // sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
185- // %2 = stablehlo.reshard %1 {sdy.sharding = <@mesh, [{"x"}, {}]>}
186- // %3 = stablehlo.negate %2 {sdy.sharding = <@mesh, [{"x"}, {}]>
187- // return %3
188- // ```
189- //
190- // In the above example, note that the operand and result shardings for
191- // stablehlo.negate op remained unchanged.
192- //
193- // Assumes factor shardings do not have overflow axes.
194- // TODO(enver): Handle the case when some factor shardings have overflow axes.
195- //
196- // Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
197- // different on device order.
198- void insertExplicitReshards (Operation* op,
199- ArrayRef<TensorShardingAttr> inShardings,
200- ArrayRef<TensorShardingAttr> outShardings,
201- const ShardingProjection& shardingProjection,
202- UpdateTensorShardings updateTensorShardings,
203- IRRewriter& rewriter,
204- OpShardingRuleAttr shardingRule,
205- const SymbolTable& symbolTable, const Mesh& mesh) {
206- rewriter.setInsertionPoint (op);
207- for (const auto & [operandIndex, operandSharding] :
208- llvm::enumerate (inShardings)) {
209- if (updateTensorShardings.updateOperands .test (operandIndex) ||
210- shouldReshardToCommonMesh (operandSharding, mesh, symbolTable)) {
211- insertExplicitReshardsOnOperand (op, operandIndex, shardingProjection,
212- shardingRule, mesh, rewriter);
213- }
214- }
215- rewriter.setInsertionPointAfter (op);
216- for (const auto & [resultIndex, resultSharding] :
217- llvm::enumerate (outShardings)) {
218- if (updateTensorShardings.updateResults .test (resultIndex) ||
219- shouldReshardToCommonMesh (resultSharding, mesh, symbolTable)) {
220- insertExplicitReshardsOnResult (op, resultIndex, shardingProjection,
221- shardingRule, mesh, rewriter);
222- }
223- }
224- }
225-
226161struct FactorAxesPair {
227162 constexpr static int64_t kEmptyFactorIndex = -1 ;
228163 constexpr static int64_t kTombstoneFactorIndex = -2 ;
@@ -796,6 +731,7 @@ void distributeAxisRefsToBatchingFactors(
796731 }
797732 }
798733}
734+ } // namespace
799735
800736AxesPerFactor findCommonAxes (ArrayRef<TensorShardingAttr> inShardings,
801737 ArrayRef<TensorShardingAttr> outShardings,
@@ -861,8 +797,6 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
861797 return tensorSizes;
862798}
863799
864- // Returns reduction axes that are the union of all axes on reduction factors.
865- // The result axes are not necessarily canonicalized.
866800SmallVector<AxisRefAttr> getReductionAxes (const AxesPerFactor& axesPerFactor,
867801 OpShardingRuleAttr shardingRule) {
868802 SmallVector<AxisRefAttr> reductionAxes;
@@ -871,7 +805,6 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
871805 }
872806 return reductionAxes;
873807}
874- } // namespace
875808
876809TensorShardingAttr insertAllReduceIfUnreducedToReplicated (
877810 OpOperand& use, TensorShardingAttr sourceSharding,
@@ -958,36 +891,32 @@ void insertAllReducesForReductionFactors(Operation* op,
958891 }
959892}
960893
// Finds a common (compatible) per-factor sharding for `op` and materializes
// it by inserting explicit reshards on the operands and results whose
// shardings had to change. Returns the axes of the reduction factors of the
// chosen common sharding (so the caller can insert all-reduces), or an empty
// vector if no common axes were found.
SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
    Operation* op, ArrayRef<TensorShardingAttr> inShardings,
    ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
    const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
    const Mesh& mesh) {
  // Project tensor shardings onto the rule's factors; dimensions without a
  // sharding are treated as closed.
  ShardingProjection shardingProjection = ShardingProjection::build(
      inShardings, outShardings, shardingRule, mesh.attr(),
      /*closedIfMissing=*/true);

  // Tracks which operand/result shardings are changed below, so that only
  // those tensors get explicit reshards.
  UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
                                              shardingRule.getNumResults());
  AxesPerFactor commonAxesPerFactor =
      findCommonAxes(inShardings, outShardings, shardingProjection,
                     shardingRule, getTensorSizes(op), symbolTable, mesh);
  // TODO(b/446833985): Return common axes factors also when the sharding
  // projection has overflow axes.
  if (commonAxesPerFactor.empty()) {
    return {};
  }
  for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
    // TODO(enver): Add unit tests to test overflow axes are cleared after
    // handling the case that some factors have overflow axes.
    updateTensorShardings |=
        shardingProjection.updateSharding(index, axes, /*overflowAxes=*/{});
  }
  insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
                         updateTensorShardings, rewriter, shardingRule,
                         symbolTable, mesh);

  return getReductionAxes(commonAxesPerFactor, shardingRule);
}
// Inserts explicit reshards for operands and results that change by the given
// `shardingProjection` for a given `op`. The reshards are inserted only to
// make the given operation compatible.
//
// Operand reshards are inserted right before `op` and result reshards right
// after it; tensors whose sharding is unchanged and already on the common
// `mesh` are left untouched.
//
// Assumes factor shardings do not have overflow axes, and that all tensor
// shardings have the same mesh as `mesh` on axes (device order may differ).
void insertExplicitReshards(Operation* op,
                            ArrayRef<TensorShardingAttr> inShardings,
                            ArrayRef<TensorShardingAttr> outShardings,
                            const ShardingProjection& shardingProjection,
                            UpdateTensorShardings updateTensorShardings,
                            IRRewriter& rewriter,
                            OpShardingRuleAttr shardingRule,
                            const SymbolTable& symbolTable, const Mesh& mesh) {
  // Operand reshards go immediately before the op.
  rewriter.setInsertionPoint(op);
  for (const auto& [operandIndex, operandSharding] :
       llvm::enumerate(inShardings)) {
    if (updateTensorShardings.updateOperands.test(operandIndex) ||
        shouldReshardToCommonMesh(operandSharding, mesh, symbolTable)) {
      insertExplicitReshardsOnOperand(op, operandIndex, shardingProjection,
                                      shardingRule, mesh, rewriter);
    }
  }
  // Result reshards go immediately after the op.
  rewriter.setInsertionPointAfter(op);
  for (const auto& [resultIndex, resultSharding] :
       llvm::enumerate(outShardings)) {
    if (updateTensorShardings.updateResults.test(resultIndex) ||
        shouldReshardToCommonMesh(resultSharding, mesh, symbolTable)) {
      insertExplicitReshardsOnResult(op, resultIndex, shardingProjection,
                                     shardingRule, mesh, rewriter);
    }
  }
}
992921
993922} // namespace sdy
0 commit comments