Skip to content

Commit 0626767

Browse files
Refactor to unify insert explicit reshards on op flows for default/minimal and full version.
It prepares for further refactorings to unify parts of the flow. PiperOrigin-RevId: 811354695
1 parent d944a51 commit 0626767

3 files changed

Lines changed: 123 additions & 113 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 27 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -158,71 +158,6 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
158158
mesh.attr().getDeviceIds();
159159
}
160160

161-
// Insert explicit reshards for operands and results that change by
162-
// the given `shardingProjection` for a given `op`. The reshards are inserted
163-
// only to make the given operation compatible.
164-
//
165-
// For example,
166-
//
167-
// ```mlir
168-
// %arg0: tensor<8x32xf32> { sdy.sharding = <@mesh, [{}, {"y"}]>}
169-
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
170-
// %0 = stablehlo.dot %arg0, %arg1 { sdy.sharding = <@mesh, [{"x"}, {}]>,
171-
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
172-
// %1 = stablehlo.negate %0 {sdy.sharding = <@mesh, [{"x"}, {}]>
173-
// return %1
174-
// ```
175-
//
176-
// after a call on the stablehlo.dot operation, by the sharding projection,
177-
// i: {}, j: {}, k: {"y"}, the module becomes:
178-
//
179-
// ```mlir
180-
// %arg0: tensor<8x32xf32> { sdy.sharding = <@mesh, [{}, {"y"}]>}
181-
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
182-
// %0 = stablehlo.reshard %arg1 {sdy.sharding = <@mesh, [{"y"}, {}]>}
183-
// %1 = stablehlo.dot %arg0, %0 { sdy.sharding = <@mesh, [{}, {}]>,
184-
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
185-
// %2 = stablehlo.reshard %1 {sdy.sharding = <@mesh, [{"x"}, {}]>}
186-
// %3 = stablehlo.negate %2 {sdy.sharding = <@mesh, [{"x"}, {}]>
187-
// return %3
188-
// ```
189-
//
190-
// In the above example, note that the operand and result shardings for
191-
// stablehlo.negate op remained unchanged.
192-
//
193-
// Assumes factor shardings do not have overflow axes.
194-
// TODO(enver): Handle the case when some factor shardings have overflow axes.
195-
//
196-
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
197-
// different on device order.
198-
void insertExplicitReshards(Operation* op,
199-
ArrayRef<TensorShardingAttr> inShardings,
200-
ArrayRef<TensorShardingAttr> outShardings,
201-
const ShardingProjection& shardingProjection,
202-
UpdateTensorShardings updateTensorShardings,
203-
IRRewriter& rewriter,
204-
OpShardingRuleAttr shardingRule,
205-
const SymbolTable& symbolTable, const Mesh& mesh) {
206-
rewriter.setInsertionPoint(op);
207-
for (const auto& [operandIndex, operandSharding] :
208-
llvm::enumerate(inShardings)) {
209-
if (updateTensorShardings.updateOperands.test(operandIndex) ||
210-
shouldReshardToCommonMesh(operandSharding, mesh, symbolTable)) {
211-
insertExplicitReshardsOnOperand(op, operandIndex, shardingProjection,
212-
shardingRule, mesh, rewriter);
213-
}
214-
}
215-
rewriter.setInsertionPointAfter(op);
216-
for (const auto& [resultIndex, resultSharding] :
217-
llvm::enumerate(outShardings)) {
218-
if (updateTensorShardings.updateResults.test(resultIndex) ||
219-
shouldReshardToCommonMesh(resultSharding, mesh, symbolTable)) {
220-
insertExplicitReshardsOnResult(op, resultIndex, shardingProjection,
221-
shardingRule, mesh, rewriter);
222-
}
223-
}
224-
}
225-
226161
struct FactorAxesPair {
227162
constexpr static int64_t kEmptyFactorIndex = -1;
228163
constexpr static int64_t kTombstoneFactorIndex = -2;
@@ -796,6 +731,7 @@ void distributeAxisRefsToBatchingFactors(
796731
}
797732
}
798733
}
734+
} // namespace
799735

800736
AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,
801737
ArrayRef<TensorShardingAttr> outShardings,
@@ -861,8 +797,6 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
861797
return tensorSizes;
862798
}
863799

864-
// Returns reduction axes that are the union of all axes on reduction factors.
865-
// The result axes are not necessarily canonicalized.
866800
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
867801
OpShardingRuleAttr shardingRule) {
868802
SmallVector<AxisRefAttr> reductionAxes;
@@ -871,7 +805,6 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
871805
}
872806
return reductionAxes;
873807
}
874-
} // namespace
875808

876809
TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
877810
OpOperand& use, TensorShardingAttr sourceSharding,
@@ -958,36 +891,32 @@ void insertAllReducesForReductionFactors(Operation* op,
958891
}
959892
}
960893

961-
SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
962-
Operation* op, ArrayRef<TensorShardingAttr> inShardings,
963-
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
964-
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
965-
const Mesh& mesh) {
966-
ShardingProjection shardingProjection = ShardingProjection::build(
967-
inShardings, outShardings, shardingRule, mesh.attr(),
968-
/*closedIfMissing=*/true);
969-
970-
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
971-
shardingRule.getNumResults());
972-
AxesPerFactor commonAxesPerFactor =
973-
findCommonAxes(inShardings, outShardings, shardingProjection,
974-
shardingRule, getTensorSizes(op), symbolTable, mesh);
975-
// TODO(b/446833985): Return common axes factors also when the sharding
976-
// projection have overflow axes.
977-
if (commonAxesPerFactor.empty()) {
978-
return {};
979-
}
980-
for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
981-
// TODO(enver): Add unit tests to test overflow axes are cleared after
982-
// handling the case that some factors have overflow axes.
983-
updateTensorShardings |=
984-
shardingProjection.updateSharding(index, axes, /*overflowAxes=*/{});
985-
}
986-
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
987-
updateTensorShardings, rewriter, shardingRule,
988-
symbolTable, mesh);
989-
990-
return getReductionAxes(commonAxesPerFactor, shardingRule);
894+
void insertExplicitReshards(Operation* op,
895+
ArrayRef<TensorShardingAttr> inShardings,
896+
ArrayRef<TensorShardingAttr> outShardings,
897+
const ShardingProjection& shardingProjection,
898+
UpdateTensorShardings updateTensorShardings,
899+
IRRewriter& rewriter,
900+
OpShardingRuleAttr shardingRule,
901+
const SymbolTable& symbolTable, const Mesh& mesh) {
902+
rewriter.setInsertionPoint(op);
903+
for (const auto& [operandIndex, operandSharding] :
904+
llvm::enumerate(inShardings)) {
905+
if (updateTensorShardings.updateOperands.test(operandIndex) ||
906+
shouldReshardToCommonMesh(operandSharding, mesh, symbolTable)) {
907+
insertExplicitReshardsOnOperand(op, operandIndex, shardingProjection,
908+
shardingRule, mesh, rewriter);
909+
}
910+
}
911+
rewriter.setInsertionPointAfter(op);
912+
for (const auto& [resultIndex, resultSharding] :
913+
llvm::enumerate(outShardings)) {
914+
if (updateTensorShardings.updateResults.test(resultIndex) ||
915+
shouldReshardToCommonMesh(resultSharding, mesh, symbolTable)) {
916+
insertExplicitReshardsOnResult(op, resultIndex, shardingProjection,
917+
shardingRule, mesh, rewriter);
918+
}
919+
}
991920
}
992921

993922
} // namespace sdy

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(TensorShardingAttr sharding);
7878
// empty axes.
7979
ArrayRef<AxisRefAttr> getUnreducedAxes(Value value);
8080

81+
// Returns a concatenated array of operand and result tensor sizes.
82+
SmallVector<int64_t> getTensorSizes(Operation* op);
83+
84+
// Returns reduction axes that are the union of all axes on reduction factors.
85+
// The result axes are not necessarily canonicalized.
86+
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
87+
OpShardingRuleAttr shardingRule);
88+
8189
// Inserts an `sdy.all-reduce` for each result of `op` if `reductionAxes`
8290
// is non-empty. Assume the followings:
8391
// - All op results have the same unreduced axes.
@@ -87,7 +95,7 @@ void insertAllReducesForReductionFactors(Operation* op,
8795
const Mesh& mesh,
8896
IRRewriter& rewriter);
8997

90-
// Inserts explicit reshards on the operands and results of `op` such that the
98+
// Finds common factor axes on the operands and results of `op` so that the
9199
// sharding of `op` is compatible with its sharding rule.
92100
//
93101
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
@@ -97,14 +105,59 @@ void insertAllReducesForReductionFactors(Operation* op,
97105
// - If the op has no results, none of the operands has unreduced axes.
98106
// - Operand and result meshes are the same ignoring device id order.
99107
//
100-
// Returns the union of axes along all the reduction factors which may not be
101-
// canonicalized.
102-
SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
103-
Operation* op, ArrayRef<TensorShardingAttr> inShardings,
104-
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
105-
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
106-
const Mesh& mesh);
107-
108+
// Returns the common axes per factor.
109+
AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,
110+
ArrayRef<TensorShardingAttr> outShardings,
111+
const ShardingProjection& shardingProjection,
112+
OpShardingRuleAttr shardingRule,
113+
ArrayRef<int64_t> tensorSizes,
114+
const SymbolTable& symbolTable, const Mesh& mesh);
115+
116+
// Insert explicit reshards for operands and results that change by
117+
// the given `shardingProjection` for a given `op`. The reshards are inserted
118+
// only to make the given operation compatible.
119+
//
120+
// For example,
121+
//
122+
// ```mlir
123+
// %arg0: tensor<8x32xf32> { sdy.sharding = <@mesh, [{}, {"y"}]>}
124+
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
125+
// %0 = stablehlo.dot %arg0, %arg1 { sdy.sharding = <@mesh, [{"x"}, {}]>,
126+
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
127+
// %1 = stablehlo.negate %0 {sdy.sharding = <@mesh, [{"x"}, {}]>
128+
// return %1
129+
// ```
130+
//
131+
// after a call on the stablehlo.dot operation, by the sharding projection,
132+
// i: {}, j: {}, k: {"y"}, the module becomes:
133+
//
134+
// ```mlir
135+
// %arg0: tensor<8x32xf32> { sdy.sharding = <@mesh, [{}, {"y"}]>}
136+
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
137+
// %0 = stablehlo.reshard %arg1 {sdy.sharding = <@mesh, [{"y"}, {}]>}
138+
// %1 = stablehlo.dot %arg0, %0 { sdy.sharding = <@mesh, [{}, {}]>,
139+
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
140+
// %2 = stablehlo.reshard %1 {sdy.sharding = <@mesh, [{"x"}, {}]>}
141+
// %3 = stablehlo.negate %2 {sdy.sharding = <@mesh, [{"x"}, {}]>
142+
// return %3
143+
// ```
144+
//
145+
// In the above example, note that the operand and result shardings for
146+
// stablehlo.negate op remained unchanged.
147+
//
148+
// Assumes factor shardings do not have overflow axes.
149+
// TODO(enver): Handle the case when some factor shardings have overflow axes.
150+
//
151+
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
152+
// different on device order.
153+
void insertExplicitReshards(Operation* op,
154+
ArrayRef<TensorShardingAttr> inShardings,
155+
ArrayRef<TensorShardingAttr> outShardings,
156+
const ShardingProjection& shardingProjection,
157+
UpdateTensorShardings updateTensorShardings,
158+
IRRewriter& rewriter,
159+
OpShardingRuleAttr shardingRule,
160+
const SymbolTable& symbolTable, const Mesh& mesh);
108161
} // namespace sdy
109162
} // namespace mlir
110163

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,21 +391,52 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
391391
return false;
392392
}
393393

394+
// Inserts explicit reshards on the operands and results of `op` such that the
395+
// sharding of `op` is compatible with its sharding rule.
396+
//
397+
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
398+
//
394399
// Assume the followings:
395400
// - All op results have the same unreduced axes.
396401
// - If the op has no results, none of the operands has unreduced axes.
402+
// - Operand and result meshes are the same ignoring device id order.
397403
//
398-
// Returns the union of common reduction axes which may not be canonicalized.
404+
// Returns the union of axes along all the reduction factors which may not be
405+
// canonicalized.
399406
SmallVector<AxisRefAttr> processOp(Operation* op,
400407
ArrayRef<TensorShardingAttr> inShardings,
401408
ArrayRef<TensorShardingAttr> outShardings,
402409
IRRewriter& rewriter,
403410
const SymbolTable& symbolTable,
404411
OpShardingRuleAttr shardingRule,
405412
const Mesh& mesh, const bool onFullVersion) {
413+
ShardingProjection shardingProjection = ShardingProjection::build(
414+
inShardings, outShardings, shardingRule, mesh.attr(),
415+
/*closedIfMissing=*/true);
416+
406417
if (onFullVersion) {
407-
return insertExplicitReshardsOnOp(op, inShardings, outShardings, rewriter,
408-
symbolTable, shardingRule, mesh);
418+
AxesPerFactor commonAxesPerFactor =
419+
findCommonAxes(inShardings, outShardings, shardingProjection,
420+
shardingRule, getTensorSizes(op), symbolTable, mesh);
421+
// TODO(b/446833985): Return common axes factors also when the sharding
422+
// projection have overflow axes.
423+
if (commonAxesPerFactor.empty()) {
424+
return {};
425+
}
426+
427+
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
428+
shardingRule.getNumResults());
429+
for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
430+
// TODO(enver): Add unit tests to test overflow axes are cleared after
431+
// handling the case that some factors have overflow axes.
432+
updateTensorShardings |=
433+
shardingProjection.updateSharding(index, axes, /*overflowAxes=*/{});
434+
}
435+
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
436+
updateTensorShardings, rewriter, shardingRule,
437+
symbolTable, mesh);
438+
439+
return getReductionAxes(commonAxesPerFactor, shardingRule);
409440
}
410441

411442
TypeSwitch<Operation*>(op)
@@ -422,9 +453,6 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
422453
return {};
423454
}
424455

425-
ShardingProjection shardingProjection = ShardingProjection::build(
426-
inShardings, outShardings, shardingRule, mesh.attr(),
427-
/*closedIfMissing=*/true);
428456
// TODO(enver): Factor out finding common axes per factor. Share logic with
429457
// getCompatibleFactorShardings.
430458
SmallVector<AxisRefAttr> reductionAxes;

0 commit comments

Comments
 (0)