Skip to content

Commit 4533c31

Browse files
Refactor to unify the insert-explicit-reshards-on-op flow between the default/minimal and full versions.
This prepares for further refactorings that unify other parts of the flow. PiperOrigin-RevId: 811354695
1 parent 17e1e18 commit 4533c31

3 files changed

Lines changed: 105 additions & 93 deletions

File tree

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 4 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ limitations under the License.
4545
namespace mlir {
4646
namespace sdy {
4747

48-
namespace {
49-
50-
// Returns true iff any tensor factor sharding has non-empty overflow axes.
5148
bool hasOverflowAxes(const ShardingProjection& shardingProjection) {
5249
for (const TensorFactorShardings& tensorFactorSharding :
5350
llvm::concat<const TensorFactorShardings>(
@@ -62,6 +59,7 @@ bool hasOverflowAxes(const ShardingProjection& shardingProjection) {
6259
return false;
6360
}
6461

62+
namespace {
6563
bool hasShardedPermutationFactors(
6664
const TensorFactorShardings& tensorFactorSharding,
6765
OpShardingRuleAttr shardingRule) {
@@ -157,44 +155,8 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
157155
sharding.getMesh(symbolTable).getDeviceIds() !=
158156
mesh.attr().getDeviceIds();
159157
}
158+
} // namespace
160159

161-
// Insert explicit reshards for operands and results that change by
162-
// the given `shardingProjection` for a given `op`. The reshards are inserted
163-
// only to make the given operation compatible.
164-
//
165-
// For example,
166-
//
167-
// ```mlir
168-
// %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
169-
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
170-
// %0 = stablehlo.dot %arg0, %arg1 { sdy.sharding = <@mesh, [{"x"}, {}]>,
171-
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
172-
// %1 = stablehlo.negate %0 {sdy.sharding = <@mesh, [{"x"}, {}]>
173-
// return %1
174-
// ```
175-
//
176-
// after a call on the stablehlo.dot operation, by the sharding projection,
177-
// i: {}, j: {}, k: {"y"}, the module becomes:
178-
//
179-
// ```mlir
180-
// %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
181-
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
182-
// %0 = stablehlo.reshard %arg1 {sdy.sharding = <@mesh, [{"y"}, {}]>}
183-
// %1 = stablehlo.dot %arg0, %0 { sdy.sharding = <@mesh, [{}, {}]>,
184-
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
185-
// %2 = stablehlo.reshard %1 {sdy.sharding = <@mesh, [{"x"}, {}]>}
186-
// %3 = stablehlo.negate %2 {sdy.sharding = <@mesh, [{"x"}, {}]>
187-
// return %3
188-
// ```
189-
//
190-
// In the above example, note that the operand and result shardings for
191-
// stablehlo.negate op remained unchanged.
192-
//
193-
// Assumes factor shardings do not have overflow axes.
194-
// TODO(enver): Handle the case when some factor shardings have overflow axes.
195-
//
196-
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
197-
// different on device order.
198160
void insertExplicitReshards(Operation* op,
199161
ArrayRef<TensorShardingAttr> inShardings,
200162
ArrayRef<TensorShardingAttr> outShardings,
@@ -223,6 +185,7 @@ void insertExplicitReshards(Operation* op,
223185
}
224186
}
225187

188+
namespace {
226189
struct FactorAxesPair {
227190
constexpr static int64_t kEmptyFactorIndex = -1;
228191
constexpr static int64_t kTombstoneFactorIndex = -2;
@@ -793,6 +756,7 @@ void distributeAxisRefsToBatchingFactors(
793756
}
794757
}
795758
}
759+
} // namespace
796760

797761
// Assumes there are no overflow axes.
798762
//
@@ -855,8 +819,6 @@ SmallVector<int64_t> getTensorSizes(Operation* op) {
855819
return tensorSizes;
856820
}
857821

858-
// Returns reduction axes that are the union of all axes on reduction factors.
859-
// The result axes are not necessarilly canonicalized.
860822
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
861823
OpShardingRuleAttr shardingRule) {
862824
SmallVector<AxisRefAttr> reductionAxes;
@@ -865,7 +827,6 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
865827
}
866828
return reductionAxes;
867829
}
868-
} // namespace
869830

870831
TensorShardingAttr insertAllReduceIfUnreducedToReplicated(
871832
OpOperand& use, TensorShardingAttr sourceSharding,
@@ -952,41 +913,5 @@ void insertAllReducesForReductionFactors(Operation* op,
952913
}
953914
}
954915

955-
SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
956-
Operation* op, ArrayRef<TensorShardingAttr> inShardings,
957-
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
958-
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
959-
const Mesh& mesh) {
960-
ShardingProjection shardingProjection = ShardingProjection::build(
961-
inShardings, outShardings, shardingRule, mesh.attr(),
962-
/*closedIfMissing=*/true);
963-
964-
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
965-
shardingRule.getNumResults());
966-
967-
// Return without inserting reshards if any factor sharding has overflow
968-
// axes. This case is not handled yet.
969-
// TODO(b/446833985): Handle the case when factor shardings have overflow
970-
// axes.
971-
if (hasOverflowAxes(shardingProjection)) {
972-
return {};
973-
}
974-
975-
AxesPerFactor commonAxesPerFactor =
976-
findCommonAxes(inShardings, outShardings, shardingProjection,
977-
shardingRule, getTensorSizes(op), symbolTable, mesh);
978-
for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
979-
// TODO(enver): Add unit tests to test overflow axes are cleared after
980-
// handling the case that some factors have overflow axes.
981-
updateTensorShardings |=
982-
shardingProjection.updateSharding(index, axes, /*overflowAxes=*/{});
983-
}
984-
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
985-
updateTensorShardings, rewriter, shardingRule,
986-
symbolTable, mesh);
987-
988-
return getReductionAxes(commonAxesPerFactor, shardingRule);
989-
}
990-
991916
} // namespace sdy
992917
} // namespace mlir

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,63 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(TensorShardingAttr sharding);
7878
// empty axes.
7979
ArrayRef<AxisRefAttr> getUnreducedAxes(Value value);
8080

81+
// Returns a concatenated array of operand and result tensor sizes.
82+
SmallVector<int64_t> getTensorSizes(Operation* op);
83+
84+
// Returns reduction axes that are the union of all axes on reduction factors.
85+
// The result axes are not necessarilly canonicalized.
86+
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
87+
OpShardingRuleAttr shardingRule);
88+
89+
// Returns true iff any tensor factor sharding has non-empty overflow axes.
90+
bool hasOverflowAxes(const ShardingProjection& shardingProjection);
91+
92+
// Insert explicit reshards for operands and results that change by
93+
// the given `shardingProjection` for a given `op`. The reshards are inserted
94+
// only to make the given operation compatible.
95+
//
96+
// For example,
97+
//
98+
// ```mlir
99+
// %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
100+
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
101+
// %0 = stablehlo.dot %arg0, %arg1 { sdy.sharding = <@mesh, [{"x"}, {}]>,
102+
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
103+
// %1 = stablehlo.negate %0 {sdy.sharding = <@mesh, [{"x"}, {}]>
104+
// return %1
105+
// ```
106+
//
107+
// after a call on the stablehlo.dot operation, by the sharding projection,
108+
// i: {}, j: {}, k: {"y"}, the module becomes:
109+
//
110+
// ```mlir
111+
// %arg0: tensor<8x32xf32> { sdy.sharding = @mesh, [{}, {"y"}]>}
112+
// %arg1: tensor<32x16xf32> { sdy.sharding = <@mesh, [{"y"}, {"x"}]>}
113+
// %0 = stablehlo.reshard %arg1 {sdy.sharding = <@mesh, [{"y"}, {}]>}
114+
// %1 = stablehlo.dot %arg0, %0 { sdy.sharding = <@mesh, [{}, {}]>,
115+
// sdy.sharding_rule = <([i, k], [k, j])->([i, j])> }
116+
// %2 = stablehlo.reshard %1 {sdy.sharding = <@mesh, [{"x"}, {}]>}
117+
// %3 = stablehlo.negate %2 {sdy.sharding = <@mesh, [{"x"}, {}]>
118+
// return %3
119+
// ```
120+
//
121+
// In the above example, note that the operand and result shardings for
122+
// stablehlo.negate op remained unchanged.
123+
//
124+
// Assumes factor shardings do not have overflow axes.
125+
// TODO(enver): Handle the case when some factor shardings have overflow axes.
126+
//
127+
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
128+
// different on device order.
129+
void insertExplicitReshards(Operation* op,
130+
ArrayRef<TensorShardingAttr> inShardings,
131+
ArrayRef<TensorShardingAttr> outShardings,
132+
const ShardingProjection& shardingProjection,
133+
UpdateTensorShardings updateTensorShardings,
134+
IRRewriter& rewriter,
135+
OpShardingRuleAttr shardingRule,
136+
const SymbolTable& symbolTable, const Mesh& mesh);
137+
81138
// Inserts an `sdy.all-reduce` for each result of `op` if `reductionAxes`
82139
// is non-empty. Assume the followings:
83140
// - All op results have the same unreduced axes.
@@ -87,7 +144,7 @@ void insertAllReducesForReductionFactors(Operation* op,
87144
const Mesh& mesh,
88145
IRRewriter& rewriter);
89146

90-
// Inserts explicit reshards on the operands and results of `op` such that the
147+
// Finds common factor axes on the operands and results of `op` so that the
91148
// sharding of `op` is compatible with its sharding rule.
92149
//
93150
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
@@ -96,14 +153,15 @@ void insertAllReducesForReductionFactors(Operation* op,
96153
// - All op results have the same unreduced axes.
97154
// - If the op has no results, none of the operands has unreduced axes.
98155
// - Operand and result meshes are the same ignoring device id order.
156+
// - There are no overflow axes.
99157
//
100-
// Returns the union of axes along all the reduction factors which may not be
101-
// canonicalized.
102-
SmallVector<AxisRefAttr> insertExplicitReshardsOnOp(
103-
Operation* op, ArrayRef<TensorShardingAttr> inShardings,
104-
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
105-
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
106-
const Mesh& mesh);
158+
// Guarantees to return a non-empty AxesPerFactor.
159+
AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,
160+
ArrayRef<TensorShardingAttr> outShardings,
161+
const ShardingProjection& shardingProjection,
162+
OpShardingRuleAttr shardingRule,
163+
ArrayRef<int64_t> tensorSizes,
164+
const SymbolTable& symbolTable, const Mesh& mesh);
107165

108166
} // namespace sdy
109167
} // namespace mlir

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,21 +391,53 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
391391
return false;
392392
}
393393

394+
// Inserts explicit reshards on the operands and results of `op` such that the
395+
// sharding of `op` is compatible with its sharding rule.
396+
//
397+
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
398+
//
394399
// Assume the followings:
395400
// - All op results have the same unreduced axes.
396401
// - If the op has no results, none of the operands has unreduced axes.
402+
// - Operand and result meshes are the same ignoring device id order.
397403
//
398-
// Returns the union of common reducation axes which may not be canonicalized.
404+
// Returns the union of axes along all the reduction factors which may not be
405+
// canonicalized.
399406
SmallVector<AxisRefAttr> processOp(Operation* op,
400407
ArrayRef<TensorShardingAttr> inShardings,
401408
ArrayRef<TensorShardingAttr> outShardings,
402409
IRRewriter& rewriter,
403410
const SymbolTable& symbolTable,
404411
OpShardingRuleAttr shardingRule,
405412
const Mesh& mesh, const bool onFullVersion) {
413+
ShardingProjection shardingProjection = ShardingProjection::build(
414+
inShardings, outShardings, shardingRule, mesh.attr(),
415+
/*closedIfMissing=*/true);
416+
406417
if (onFullVersion) {
407-
return insertExplicitReshardsOnOp(op, inShardings, outShardings, rewriter,
408-
symbolTable, shardingRule, mesh);
418+
// Return without inserting reshards if any factor sharding has overflow
419+
// axes. This case is not handled yet.
420+
// TODO(b/446833985): Handle the case when factor shardings have overflow
421+
// axes.
422+
if (hasOverflowAxes(shardingProjection)) {
423+
return {};
424+
}
425+
AxesPerFactor commonAxesPerFactor =
426+
findCommonAxes(inShardings, outShardings, shardingProjection,
427+
shardingRule, getTensorSizes(op), symbolTable, mesh);
428+
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
429+
shardingRule.getNumResults());
430+
for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
431+
// TODO(enver): Add unit tests to test overflow axes are cleared after
432+
// handling the case that some factors have overflow axes.
433+
updateTensorShardings |=
434+
shardingProjection.updateSharding(index, axes, /*overflowAxes=*/{});
435+
}
436+
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
437+
updateTensorShardings, rewriter, shardingRule,
438+
symbolTable, mesh);
439+
440+
return getReductionAxes(commonAxesPerFactor, shardingRule);
409441
}
410442

411443
TypeSwitch<Operation*>(op)
@@ -422,9 +454,6 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
422454
return {};
423455
}
424456

425-
ShardingProjection shardingProjection = ShardingProjection::build(
426-
inShardings, outShardings, shardingRule, mesh.attr(),
427-
/*closedIfMissing=*/true);
428457
// TODO(enver): Factor out finding common axes per factor. Share logic with
429458
// getCompatibleFactorShardings.
430459
SmallVector<AxisRefAttr> reductionAxes;

0 commit comments

Comments
 (0)