Skip to content

Commit 615c108

Browse files
Use potentially modified sharding projection to decide on reduction factor shardings on the default/minimal version.
PiperOrigin-RevId: 812780059
1 parent b426498 commit 615c108

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,17 +158,14 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
158158
// return %reshard : tensor<4x8xf32>
159159
// ```
160160
template <class OpTy>
161-
void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
161+
void processDot(OpTy op, ShardingProjection& shardingProjection,
162162
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
163163
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
164164
const Mesh& mesh) {
165165
if (outShardings.empty()) {
166166
// Result doesn't have a sharding.
167167
return;
168168
}
169-
ShardingProjection shardingProjection =
170-
ShardingProjection::build(inShardings, outShardings, shardingRule,
171-
mesh.attr(), /*closedIfMissing=*/true);
172169

173170
const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0);
174171
const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1);
@@ -449,11 +446,11 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
449446

450447
TypeSwitch<Operation*>(op)
451448
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452-
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
453-
shardingRule, mesh);
449+
processDot(dotOp, shardingProjection, outShardings, rewriter,
450+
symbolTable, shardingRule, mesh);
454451
})
455452
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456-
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
453+
processDot(dotGeneralOp, shardingProjection, outShardings, rewriter,
457454
symbolTable, shardingRule, mesh);
458455
});
459456

0 commit comments

Comments
 (0)