Skip to content

Commit 2bb09b4

Browse files
ekayaaslan authored and copybara-github committed
Support funcs and calls alongside named computations in the constant splitter.
This prepares for pushing Shardy inlining past the constant splitter. PiperOrigin-RevId: 900743953
1 parent dfb6bde commit 2bb09b4

File tree

2 files changed

+742
-34
lines changed

2 files changed

+742
-34
lines changed

shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc

Lines changed: 60 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "mlir/IR/OpDefinition.h"
2727
#include "mlir/IR/Operation.h"
2828
#include "mlir/IR/PatternMatch.h"
29+
#include "mlir/IR/SymbolTable.h"
2930
#include "mlir/IR/Value.h"
3031
#include "mlir/IR/Visitors.h"
3132
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -48,6 +49,7 @@ namespace sdy {
4849

4950
namespace {
5051

52+
using func::CallOp;
5153
using func::FuncOp;
5254

5355
void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
@@ -66,10 +68,16 @@ void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
6668
// - A broadcast, reshape or slice op.
6769
// - An elementwise op.
6870
// - A named computation whose operations are all constant preserving.
71+
// - A call to a func whose operations are all constant preserving.
6972
// Assumes the op is not constant or iota.
7073
bool isConstantPreserving(
7174
Operation* op,
72-
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
75+
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
76+
const llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
77+
const SymbolTable& symbolTable) {
78+
if (CallOp callOp = dyn_cast<CallOp>(op)) {
79+
return !nonConstFuncOps.contains(getOriginalFuncName(callOp, symbolTable));
80+
}
7381
if (auto namedComputationOp = dyn_cast<NamedComputationOp>(op)) {
7482
return !nonConstantNamedComputationOps.contains(
7583
namedComputationOp.getName());
@@ -93,11 +101,14 @@ bool isConstantPreserving(
93101
// constants, that is, exist in `constantOps`.
94102
bool isConstantExpression(
95103
Operation* op, const llvm::SetVector<Operation*>& constantOps,
96-
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
104+
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
105+
const llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
106+
const SymbolTable& symbolTable) {
97107
if (isa<ConstantOp, stablehlo::IotaOp>(op)) {
98108
return true;
99109
}
100-
return isConstantPreserving(op, nonConstantNamedComputationOps) &&
110+
return isConstantPreserving(op, nonConstantNamedComputationOps,
111+
nonConstFuncOps, symbolTable) &&
101112
llvm::all_of(op->getOperands(), [&](Value operand) {
102113
return operand.getDefiningOp() &&
103114
constantOps.contains(operand.getDefiningOp());
@@ -117,20 +128,26 @@ bool isScalarExpansion(Operation* op) {
117128
// Recursively clones all operands of the given op, that are not already mapped
118129
// in `mapping`, and finally clones the op itself. We do not clone scalars as
119130
// they do not get sharded.
120-
void cloneSubComputation(OpResult opResult, IRMapping& mapping) {
131+
void cloneSubComputation(OpResult opResult, IRMapping& mapping,
132+
SymbolTable& symbolTable) {
121133
if (isScalar(opResult) || mapping.lookupOrNull(opResult)) {
122134
return;
123135
}
124136
Operation* op = opResult.getOwner();
125137
for (Value operand : op->getOperands()) {
126138
if (auto defOpResult = dyn_cast<OpResult>(operand)) {
127-
cloneSubComputation(defOpResult, mapping);
139+
cloneSubComputation(defOpResult, mapping, symbolTable);
128140
}
129141
}
130142

131143
// This will insert the cloned op right before the original op.
132144
OpBuilder builder(op);
133-
builder.clone(*op, mapping);
145+
Operation* clonedOp = builder.clone(*op, mapping);
146+
if (CallOp callOp = dyn_cast<CallOp>(clonedOp)) {
147+
FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
148+
callOp.setCallee(
149+
symbolTable.insert(cloneFuncRecursively(funcOp, symbolTable)));
150+
}
134151
cloneShardingGroupUsers(opResult, mapping, builder);
135152
}
136153

@@ -139,18 +156,19 @@ void cloneSubComputation(OpResult opResult, IRMapping& mapping) {
139156
// sharded.
140157
//
141158
// Returns the cloned op result.
142-
Value cloneSubComputation(OpResult opResult) {
159+
Value cloneSubComputation(OpResult opResult, SymbolTable& symbolTable) {
143160
if (isScalar(opResult)) {
144161
return opResult;
145162
}
146163
IRMapping mapping;
147-
cloneSubComputation(opResult, mapping);
164+
cloneSubComputation(opResult, mapping, symbolTable);
148165
return mapping.lookup(opResult);
149166
}
150167

151168
void cloneSubComputationOnOperands(
152169
Operation* op, const llvm::SetVector<Operation*>& constantOps,
153-
const llvm::SetVector<Operation*>& scalarExpansionOps) {
170+
const llvm::SetVector<Operation*>& scalarExpansionOps,
171+
SymbolTable& symbolTable) {
154172
for (OpOperand& operand : op->getOpOperands()) {
155173
if (auto defOpResult = dyn_cast<OpResult>(operand.get());
156174
defOpResult && (constantOps.contains(defOpResult.getOwner()) ||
@@ -160,38 +178,45 @@ void cloneSubComputationOnOperands(
160178
// `defOpResult`, and replace the `operand` with the cloned defining
161179
// op. The cloned constant sub-computation has only one user `op`,
162180
// so that it is isolated from the rest of the computation.
163-
operand.set(cloneSubComputation(defOpResult));
181+
operand.set(cloneSubComputation(defOpResult, symbolTable));
164182
}
165183
}
166184
}
167185

168-
void processOp(Operation* op, llvm::SetVector<Operation*>& constantOps,
186+
void processOp(Operation* op, FuncOp funcOp,
187+
llvm::SetVector<Operation*>& constantOps,
169188
llvm::SetVector<Operation*>& scalarExpansionOps,
170-
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
171-
if (isa<ShardingGroupOp>(op)) {
189+
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
190+
llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
191+
SymbolTable& symbolTable) {
192+
if (isa<FuncOp, ShardingGroupOp>(op)) {
172193
return;
173194
}
174-
if (isConstantExpression(op, constantOps, nonConstantNamedComputationOps)) {
195+
if (isConstantExpression(op, constantOps, nonConstantNamedComputationOps,
196+
nonConstFuncOps, symbolTable)) {
175197
constantOps.insert(op);
176198
return;
177199
}
178200
// NOTE: There are cases where op is a constant expression but may not pass
179201
// the following check such as constant and iota ops. That is fine because if
180202
// the op is a constant expression it is a stronger condition than being just
181-
// constant preserving and it does not make the parent named computation
182-
// non-const, and at this point, it is guaranteed that the op is not constant
183-
// expression.
184-
if (!isConstantPreserving(op, nonConstantNamedComputationOps) &&
203+
// constant preserving and it does not make the parent named computation or
204+
// the `funcOp` non-const, and at this point, it is guaranteed that the op is
205+
// not a constant expression.
206+
if (!isConstantPreserving(op, nonConstantNamedComputationOps, nonConstFuncOps,
207+
symbolTable) &&
185208
!op->hasTrait<OpTrait::IsTerminator>()) {
186209
if (auto namedCompuationOp = op->getParentOfType<NamedComputationOp>()) {
187210
nonConstantNamedComputationOps.insert(namedCompuationOp.getName());
188211
}
212+
nonConstFuncOps.insert(getOriginalFuncName(funcOp));
189213
}
190214
if (isScalarExpansion(op)) {
191215
scalarExpansionOps.insert(op);
192216
return;
193217
}
194-
cloneSubComputationOnOperands(op, constantOps, scalarExpansionOps);
218+
cloneSubComputationOnOperands(op, constantOps, scalarExpansionOps,
219+
symbolTable);
195220
}
196221

197222
// Converts stablehlo::ConstantOp to sdy::ConstantOp.
@@ -240,16 +265,16 @@ struct ConstantOrScalarSplitterPass
240265
}
241266
}
242267

243-
// Assumes that the `NamedComputationOp` of the region are already walked, and
244-
// skips walking on them.
245268
void walkOnRegion(
246-
mlir::Region& region,
247-
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
269+
mlir::Region& region, FuncOp funcOp,
270+
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
271+
llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
272+
SymbolTable& symbolTable) {
248273
llvm::SetVector<Operation*> constantOps;
249274
llvm::SetVector<Operation*> scalarExpansionOps;
250275
region.walk<WalkOrder::PreOrder>([&](Operation* op) {
251-
processOp(op, constantOps, scalarExpansionOps,
252-
nonConstantNamedComputationOps);
276+
processOp(op, funcOp, constantOps, scalarExpansionOps,
277+
nonConstantNamedComputationOps, nonConstFuncOps, symbolTable);
253278
// Skip walking on the `NamedComputationOp`.
254279
if (isa<NamedComputationOp>(op)) {
255280
return WalkResult::skip();
@@ -267,6 +292,7 @@ struct ConstantOrScalarSplitterPass
267292

268293
void runOnOperation() final {
269294
ModuleOp moduleOp = getOperation();
295+
SymbolTable symbolTable(moduleOp);
270296

271297
// We first convert any `stablehlo::ConstantOp` to an `sdy::ConstantOp`, so
272298
// that constants won't be deduped via folding.
@@ -276,15 +302,15 @@ struct ConstantOrScalarSplitterPass
276302

277303
// Then we split constant sub-computations for each non-constant user.
278304
llvm::SmallDenseSet<StringRef> nonConstantNamedComputationOps;
279-
// Iterate on a post-order of NamedComputationOp blocks.
280-
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
281-
walkOnRegion(namedComputationOp.getBody(),
282-
nonConstantNamedComputationOps);
283-
});
284-
// Iterate order does not matter. Funcs do not call each other. The calls
285-
// are inlined to NamedComputationOps.
286-
moduleOp.walk([&](FuncOp funcOp) {
287-
walkOnRegion(funcOp.getBody(), nonConstantNamedComputationOps);
305+
llvm::SmallDenseSet<StringRef> nonConstFuncOps;
306+
iterateFuncs(moduleOp, [&](FuncOp funcOp) {
307+
funcOp.walk([&](NamedComputationOp namedComputationOp) {
308+
walkOnRegion(namedComputationOp.getBody(), funcOp,
309+
nonConstantNamedComputationOps, nonConstFuncOps,
310+
symbolTable);
311+
});
312+
walkOnRegion(funcOp.getBody(), funcOp, nonConstantNamedComputationOps,
313+
nonConstFuncOps, symbolTable);
288314
});
289315
}
290316

0 commit comments

Comments
 (0)