Skip to content

Commit 356dda9

Browse files
wsmoses authored and copybara-github committed
Shardy: don't replicate the update operand of dynamic_update_slice (DUS) when all start indices are constant and Enzyme comm opts are on
PiperOrigin-RevId: 900259511
1 parent 2dd4be9 commit 356dda9

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -834,6 +834,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
834834
SmallVector<int64_t> operandDims(
835835
dynamicUpdateSlice->getNumOperands(), kNullDim);
836836
OpShardingRuleBuilder builder(dynamicUpdateSlice);
837+
838+
bool allIndicesConstant = llvm::all_of(dynamicUpdateSlice.getStartIndices(), [](Value v) {
839+
Operation* defOp = v.getDefiningOp();
840+
return defOp && (isa<sdy::ConstantOp>(defOp) || defOp->hasTrait<OpTrait::ConstantLike>());
841+
});
842+
837843
for (auto [dim, dimSizes] :
838844
llvm::enumerate(llvm::zip_equal(operandShape, updateShape))) {
839845
auto [operandDimSize, updateDimSize] = dimSizes;
@@ -846,15 +852,20 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
846852
//
847853
// Thus, we add a factor for the operand/result slicing
848854
// dimension with kPassThrough type. We also add a unique factor
849-
// for the update with kNeedReplication type.
855+
// for the update with kNeedReplication type (unless all indices
856+
// are constant and we can rely on Enzyme comms opt).
850857
operandDims[0] = dim;
851858
operandDims[1] = kNullDim;
852859
builder.addFactor(operandDims, dim, operandDimSize);
853860

854861
operandDims[0] = kNullDim;
855862
operandDims[1] = dim;
856-
builder.addFactor(operandDims, kNullDim, updateDimSize,
857-
FactorType::kNeedReplication);
863+
if (!allIndicesConstant) {
864+
builder.addFactor(operandDims, kNullDim, updateDimSize,
865+
FactorType::kNeedReplication);
866+
} else {
867+
builder.addFactor(operandDims, kNullDim, updateDimSize);
868+
}
858869
}
859870
}
860871
return builder.build();

shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -585,6 +585,15 @@ func.func @dynamic_update_slice(%arg0: tensor<32x4x8xf32>, %arg1: tensor<32x1x2x
585585
return %0 : tensor<32x4x8xf32>
586586
}
587587

588+
// CHECK-LABEL: func @dynamic_update_slice_constant_indices
589+
func.func @dynamic_update_slice_constant_indices(%arg0: tensor<32x4x8xf32>, %arg1: tensor<32x1x2xf32>) -> tensor<32x4x8xf32> {
590+
%0 = stablehlo.constant dense<0> : tensor<i32>
591+
%1 = sdy.constant dense<0> : tensor<i32>
592+
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, l], [i, k, m], [], [], [])->([i, j, l]) {i=32, j=4, k=1, l=8, m=2}>
593+
%2 = stablehlo.dynamic_update_slice %arg0, %arg1, %0, %0, %1 : (tensor<32x4x8xf32>, tensor<32x1x2xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32x4x8xf32>
594+
return %2 : tensor<32x4x8xf32>
595+
}
596+
588597
// CHECK-LABEL: func @fft
589598
func.func @fft(%arg0: tensor<8x32x64xcomplex<f32>>) -> tensor<8x32x64xcomplex<f32>> {
590599
// CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, j, k]) {i=8, j=32, k=64} need_replication={j, k}>

0 commit comments

Comments
 (0)