@@ -16,6 +16,7 @@ limitations under the License.
1616#include < cassert>
1717#include < utility>
1818
19+ #include " llvm/ADT/DenseSet.h"
1920#include " llvm/ADT/STLExtras.h"
2021#include " mlir/Dialect/Func/IR/FuncOps.h"
2122#include " mlir/IR/Builders.h"
@@ -55,20 +56,31 @@ void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
5556 }
5657 }
5758}
59+
5860// A constant preserving op is an op that is considered a constant expression if
5961// it is pure and all its results can be considered as constant expressions
6062// given all its operands are constant expressions, for which it holds if the
6163// given op is either:
6264// - A broadcast, reshape or slice op.
6365// - An elementwise op.
66+ // - A named computation whose operations are all constant preserving.
6467// Assumes the op is not constant or iota.
65- bool isConstantPreserving (Operation* op) {
68+ bool isConstantPreserving (
69+ Operation* op,
70+ const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
71+ if (!isPure (op)) {
72+ return false ;
73+ }
6674 if (isa<stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
6775 stablehlo::SliceOp>(op)) {
68- return isPure (op) ;
76+ return true ;
6977 }
7078 if (isElementwise (op)) {
71- return isPure (op);
79+ return true ;
80+ }
81+ if (auto namedComputationOp = dyn_cast<NamedComputationOp>(op)) {
82+ return !nonConstantNamedComputationOps.contains (
83+ namedComputationOp.getName ());
7284 }
7385 return false ;
7486}
@@ -77,12 +89,13 @@ bool isConstantPreserving(Operation* op) {
7789// - A constant or iota op.
7890// - A constant preserving op (see isConstantPreserving) whose operands are all
7991// constants, that is, exist in `constantOps`.
80- bool isConstantExpression (Operation* op,
81- const llvm::SetVector<Operation*>& constantOps) {
92+ bool isConstantExpression (
93+ Operation* op, const llvm::SetVector<Operation*>& constantOps,
94+ const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
8295 if (isa<ConstantOp, stablehlo::IotaOp>(op)) {
8396 return true ;
8497 }
85- return isConstantPreserving (op) &&
98+ return isConstantPreserving (op, nonConstantNamedComputationOps ) &&
8699 llvm::all_of (op->getOperands (), [&](Value operand) {
87100 return operand.getDefiningOp () &&
88101 constantOps.contains (operand.getDefiningOp ());
@@ -151,14 +164,27 @@ void cloneSubComputationOnOperands(
151164}
152165
153166void processOp (Operation* op, llvm::SetVector<Operation*>& constantOps,
154- llvm::SetVector<Operation*>& scalarExpansionOps) {
167+ llvm::SetVector<Operation*>& scalarExpansionOps,
168+ llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
155169 if (isa<ShardingGroupOp>(op)) {
156170 return ;
157171 }
158- if (isConstantExpression (op, constantOps)) {
172+ if (isConstantExpression (op, constantOps, nonConstantNamedComputationOps )) {
159173 constantOps.insert (op);
160174 return ;
161175 }
176+ // NOTE: Some constant expressions, such as constant and iota ops, do not
177+ // pass the following check. That is fine: being a constant expression is a
178+ // stronger condition than being constant preserving, so such ops do not make
179+ // the parent named computation non-constant. Moreover, at this point the op
180+ // is guaranteed not to be a constant expression, since that case returned
181+ // early above.
182+ if (!isConstantPreserving (op, nonConstantNamedComputationOps) &&
183+ !op->hasTrait <OpTrait::IsTerminator>()) {
184+ if (auto namedCompuationOp = op->getParentOfType <NamedComputationOp>()) {
185+ nonConstantNamedComputationOps.insert (namedCompuationOp.getName ());
186+ }
187+ }
162188 if (isScalarExpansion (op)) {
163189 scalarExpansionOps.insert (op);
164190 return ;
@@ -224,6 +250,7 @@ struct ConstantOrScalarSplitterPass
224250 // Then we split constant sub-computations for each non-constant user.
225251 SmallVector<llvm::SetVector<Operation*>> constantOps;
226252 llvm::SetVector<Operation*> scalarExpansionOps;
253+ llvm::SmallDenseSet<StringRef> nonConstantNamedComputationOps;
227254 constantOps.emplace_back ();
228255 funcOp.walk <WalkOrder::PreOrder>([&](Operation* op) {
229256 // Because it is a preorder walk, it visits the NamedComputationOp before
@@ -236,13 +263,15 @@ struct ConstantOrScalarSplitterPass
236263 constantOps.emplace_back ();
237264 return ;
238265 }
239- processOp (op, constantOps.back (), scalarExpansionOps);
266+ processOp (op, constantOps.back (), scalarExpansionOps,
267+ nonConstantNamedComputationOps);
240268 if (isa<sdy::ReturnOp>(op) &&
241269 isa<NamedComputationOp>(op->getParentOp ())) {
242270 eraseUnusedOpsAlongWithItsShardingGroupUsers (
243271 llvm::reverse (constantOps.back ()));
244272 constantOps.pop_back ();
245- processOp (op->getParentOp (), constantOps.back (), scalarExpansionOps);
273+ processOp (op->getParentOp (), constantOps.back (), scalarExpansionOps,
274+ nonConstantNamedComputationOps);
246275 return ;
247276 }
248277 });
0 commit comments