Skip to content

Commit 7e64e3f

Browse files
Expand constant expressions on named computations with all constant like ops.
PiperOrigin-RevId: 804851791
1 parent a058dbb commit 7e64e3f

2 files changed

Lines changed: 219 additions & 53 deletions

File tree

shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include <cassert>
1717
#include <utility>
1818

19+
#include "llvm/ADT/DenseSet.h"
1920
#include "llvm/ADT/STLExtras.h"
2021
#include "mlir/Dialect/Func/IR/FuncOps.h"
2122
#include "mlir/IR/Builders.h"
@@ -55,20 +56,31 @@ void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
5556
}
5657
}
5758
}
59+
5860
// A constant preserving op is an op that is considered a constant expression if
5961
// it is pure and all its results can be considered as constant expressions
6062
// given all its operands are constant expressions, for which it holds if the
6163
// given op is either:
6264
// - A broadcast, reshape or slice op.
6365
// - An elementwise op.
66+
// - A named computation all operations are constant preserving.
6467
// Assumes the op is not constant or iota.
65-
bool isConstantPreserving(Operation* op) {
68+
bool isConstantPreserving(
69+
Operation* op,
70+
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
71+
if (!isPure(op)) {
72+
return false;
73+
}
6674
if (isa<stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
6775
stablehlo::SliceOp>(op)) {
68-
return isPure(op);
76+
return true;
6977
}
7078
if (isElementwise(op)) {
71-
return isPure(op);
79+
return true;
80+
}
81+
if (auto namedComputationOp = dyn_cast<NamedComputationOp>(op)) {
82+
return !nonConstantNamedComputationOps.contains(
83+
namedComputationOp.getName());
7284
}
7385
return false;
7486
}
@@ -77,12 +89,13 @@ bool isConstantPreserving(Operation* op) {
7789
// - A constant or iota op.
7890
// - A constant preserving op. (see isConstantPreserving) and all operands are
7991
// constants, that is, exist in `constantOps`.
80-
bool isConstantExpression(Operation* op,
81-
const llvm::SetVector<Operation*>& constantOps) {
92+
bool isConstantExpression(
93+
Operation* op, const llvm::SetVector<Operation*>& constantOps,
94+
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
8295
if (isa<ConstantOp, stablehlo::IotaOp>(op)) {
8396
return true;
8497
}
85-
return isConstantPreserving(op) &&
98+
return isConstantPreserving(op, nonConstantNamedComputationOps) &&
8699
llvm::all_of(op->getOperands(), [&](Value operand) {
87100
return operand.getDefiningOp() &&
88101
constantOps.contains(operand.getDefiningOp());
@@ -151,14 +164,27 @@ void cloneSubComputationOnOperands(
151164
}
152165

153166
void processOp(Operation* op, llvm::SetVector<Operation*>& constantOps,
154-
llvm::SetVector<Operation*>& scalarExpansionOps) {
167+
llvm::SetVector<Operation*>& scalarExpansionOps,
168+
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
155169
if (isa<ShardingGroupOp>(op)) {
156170
return;
157171
}
158-
if (isConstantExpression(op, constantOps)) {
172+
if (isConstantExpression(op, constantOps, nonConstantNamedComputationOps)) {
159173
constantOps.insert(op);
160174
return;
161175
}
176+
// NOTE: There are cases that op is an constant expression but may not pass
177+
// the following check such as constant and iota ops. That is fine because if
178+
// the op is a constant expression it is a stronger condition than being just
179+
// constant preserving and it does not make the parent named computation
180+
// non-const, and at this point, it is guaranteed that the op is not constant
181+
// expression.
182+
if (!isConstantPreserving(op, nonConstantNamedComputationOps) &&
183+
!op->hasTrait<OpTrait::IsTerminator>()) {
184+
if (auto namedCompuationOp = op->getParentOfType<NamedComputationOp>()) {
185+
nonConstantNamedComputationOps.insert(namedCompuationOp.getName());
186+
}
187+
}
162188
if (isScalarExpansion(op)) {
163189
scalarExpansionOps.insert(op);
164190
return;
@@ -224,6 +250,7 @@ struct ConstantOrScalarSplitterPass
224250
// Then we split constant sub-computations for each non-constant user.
225251
SmallVector<llvm::SetVector<Operation*>> constantOps;
226252
llvm::SetVector<Operation*> scalarExpansionOps;
253+
llvm::SmallDenseSet<StringRef> nonConstantNamedComputationOps;
227254
constantOps.emplace_back();
228255
funcOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
229256
// Becuase it is a preorder walk, it visits the NamedComputationOp before
@@ -236,13 +263,15 @@ struct ConstantOrScalarSplitterPass
236263
constantOps.emplace_back();
237264
return;
238265
}
239-
processOp(op, constantOps.back(), scalarExpansionOps);
266+
processOp(op, constantOps.back(), scalarExpansionOps,
267+
nonConstantNamedComputationOps);
240268
if (isa<sdy::ReturnOp>(op) &&
241269
isa<NamedComputationOp>(op->getParentOp())) {
242270
eraseUnusedOpsAlongWithItsShardingGroupUsers(
243271
llvm::reverse(constantOps.back()));
244272
constantOps.pop_back();
245-
processOp(op->getParentOp(), constantOps.back(), scalarExpansionOps);
273+
processOp(op->getParentOp(), constantOps.back(), scalarExpansionOps,
274+
nonConstantNamedComputationOps);
246275
return;
247276
}
248277
});

0 commit comments

Comments
 (0)