@@ -26,6 +26,7 @@ limitations under the License.
2626#include " mlir/IR/OpDefinition.h"
2727#include " mlir/IR/Operation.h"
2828#include " mlir/IR/PatternMatch.h"
29+ #include " mlir/IR/SymbolTable.h"
2930#include " mlir/IR/Value.h"
3031#include " mlir/IR/Visitors.h"
3132#include " mlir/Interfaces/SideEffectInterfaces.h"
@@ -48,6 +49,7 @@ namespace sdy {
4849
4950namespace {
5051
52+ using func::CallOp;
5153using func::FuncOp;
5254
5355void cloneShardingGroupUsers (OpResult opResult, IRMapping& mapping,
@@ -66,10 +68,16 @@ void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
6668// - A broadcast, reshape or slice op.
6769// - An elementwise op.
6870// - A named computation whose operations are all constant preserving.
71+ // - A call to a func whose operations are all constant preserving.
6972// Assumes the op is not constant or iota.
7073bool isConstantPreserving (
7174 Operation* op,
72- const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
75+ const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
76+ const llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
77+ const SymbolTable& symbolTable) {
78+ if (CallOp callOp = dyn_cast<CallOp>(op)) {
79+ return !nonConstFuncOps.contains (getOriginalFuncName (callOp, symbolTable));
80+ }
7381 if (auto namedComputationOp = dyn_cast<NamedComputationOp>(op)) {
7482 return !nonConstantNamedComputationOps.contains (
7583 namedComputationOp.getName ());
@@ -93,11 +101,14 @@ bool isConstantPreserving(
93101// constants, that is, exist in `constantOps`.
94102bool isConstantExpression (
95103 Operation* op, const llvm::SetVector<Operation*>& constantOps,
96- const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
104+ const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
105+ const llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
106+ const SymbolTable& symbolTable) {
97107 if (isa<ConstantOp, stablehlo::IotaOp>(op)) {
98108 return true ;
99109 }
100- return isConstantPreserving (op, nonConstantNamedComputationOps) &&
110+ return isConstantPreserving (op, nonConstantNamedComputationOps,
111+ nonConstFuncOps, symbolTable) &&
101112 llvm::all_of (op->getOperands (), [&](Value operand) {
102113 return operand.getDefiningOp () &&
103114 constantOps.contains (operand.getDefiningOp ());
@@ -117,20 +128,26 @@ bool isScalarExpansion(Operation* op) {
117128// Recursively clones all operands of the given op, that are not already mapped
118129// in `mapping`, and finally clones the op itself. We do not clone scalars as
119130// they do not get sharded.
120- void cloneSubComputation (OpResult opResult, IRMapping& mapping) {
131+ void cloneSubComputation (OpResult opResult, IRMapping& mapping,
132+ SymbolTable& symbolTable) {
121133 if (isScalar (opResult) || mapping.lookupOrNull (opResult)) {  // Skip scalars and results that already have an entry in `mapping`.
122134 return ;
123135 }
124136 Operation* op = opResult.getOwner ();
125137 for (Value operand : op->getOperands ()) {
126138 if (auto defOpResult = dyn_cast<OpResult>(operand)) {
127- cloneSubComputation (defOpResult, mapping);
139+ cloneSubComputation (defOpResult, mapping, symbolTable );
128140 }
129141 }
130142
131143 // This will insert the cloned op right before the original op.
132144 OpBuilder builder (op);
133- builder.clone (*op, mapping);
145+ Operation* clonedOp = builder.clone (*op, mapping);
146+ if (CallOp callOp = dyn_cast<CallOp>(clonedOp)) {  // For a cloned call: look up its callee, clone that func via `cloneFuncRecursively`, insert the clone into `symbolTable`, and retarget the cloned call at the name returned by `insert`.
147+ FuncOp funcOp = symbolTable.lookup <FuncOp>(callOp.getCallee ());  // NOTE(review): the lookup result is not null-checked before use -- confirm the callee always resolves to a FuncOp.
148+ callOp.setCallee (
149+ symbolTable.insert (cloneFuncRecursively (funcOp, symbolTable)));
150+ }
134151 cloneShardingGroupUsers (opResult, mapping, builder);
135152}
136153
@@ -139,18 +156,19 @@ void cloneSubComputation(OpResult opResult, IRMapping& mapping) {
139156// sharded.
140157//
141158// Returns the cloned op result.
142- Value cloneSubComputation (OpResult opResult) {
159+ Value cloneSubComputation (OpResult opResult, SymbolTable& symbolTable ) {
143160 if (isScalar (opResult)) {
144161 return opResult;  // Scalars are returned as-is rather than cloned.
145162 }
146163 IRMapping mapping;  // Fresh mapping, so nothing is treated as already cloned.
147- cloneSubComputation (opResult, mapping);
164+ cloneSubComputation (opResult, mapping, symbolTable );
148165 return mapping.lookup (opResult);  // The value `opResult` was mapped to by the call above.
149166}
150167
151168void cloneSubComputationOnOperands (
152169 Operation* op, const llvm::SetVector<Operation*>& constantOps,
153- const llvm::SetVector<Operation*>& scalarExpansionOps) {
170+ const llvm::SetVector<Operation*>& scalarExpansionOps,
171+ SymbolTable& symbolTable) {
154172 for (OpOperand& operand : op->getOpOperands ()) {
155173 if (auto defOpResult = dyn_cast<OpResult>(operand.get ());
156174 defOpResult && (constantOps.contains (defOpResult.getOwner ()) ||
@@ -160,38 +178,45 @@ void cloneSubComputationOnOperands(
160178 // `defOpResult`, and replace the `operand` with the cloned defining
161179 // op. The cloned constant sub-computation has only one user `op`,
162180 // so that it is isolated from the rest of the computation.
163- operand.set (cloneSubComputation (defOpResult));
181+ operand.set (cloneSubComputation (defOpResult, symbolTable ));
164182 }
165183 }
166184}
167185
168- void processOp (Operation* op, llvm::SetVector<Operation*>& constantOps,
186+ void processOp (Operation* op, FuncOp funcOp,
187+ llvm::SetVector<Operation*>& constantOps,
169188 llvm::SetVector<Operation*>& scalarExpansionOps,
170- llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
171- if (isa<ShardingGroupOp>(op)) {
189+ llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
190+ llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
191+ SymbolTable& symbolTable) {
192+ if (isa<FuncOp, ShardingGroupOp>(op)) {
172193 return ;
173194 }
174- if (isConstantExpression (op, constantOps, nonConstantNamedComputationOps)) {
195+ if (isConstantExpression (op, constantOps, nonConstantNamedComputationOps,
196+ nonConstFuncOps, symbolTable)) {
175197 constantOps.insert (op);
176198 return ;
177199 }
178200 // NOTE: There are cases that op is a constant expression but may not pass
179201 // the following check such as constant and iota ops. That is fine because if
180202 // the op is a constant expression it is a stronger condition than being just
181- // constant preserving and it does not make the parent named computation
182- // non-const, and at this point, it is guaranteed that the op is not constant
183- // expression.
184- if (!isConstantPreserving (op, nonConstantNamedComputationOps) &&
203+ // constant preserving and it does not make the parent named computation or
204+ // the `funcOp` non-const, and at this point, it is guaranteed that the op is
205+ // not a constant expression.
206+ if (!isConstantPreserving (op, nonConstantNamedComputationOps, nonConstFuncOps,
207+ symbolTable) &&
185208 !op->hasTrait <OpTrait::IsTerminator>()) {
186209 if (auto namedCompuationOp = op->getParentOfType <NamedComputationOp>()) {
187210 nonConstantNamedComputationOps.insert (namedCompuationOp.getName ());
188211 }
212+ nonConstFuncOps.insert (getOriginalFuncName (funcOp));  // An op that is neither constant preserving nor a terminator marks `funcOp` as non-const, whether or not it sits inside a named computation.
189213 }
190214 if (isScalarExpansion (op)) {
191215 scalarExpansionOps.insert (op);
192216 return ;
193217 }
194- cloneSubComputationOnOperands (op, constantOps, scalarExpansionOps);
218+ cloneSubComputationOnOperands (op, constantOps, scalarExpansionOps,
219+ symbolTable);
195220}
196221
197222// Converts stablehlo::ConstantOp to sdy::ConstantOp.
@@ -240,16 +265,16 @@ struct ConstantOrScalarSplitterPass
240265 }
241266 }
242267
243- // Assumes that the `NamedComputationOp` of the region are already walked, and
244- // skips walking on them.
245268 void walkOnRegion (
246- mlir::Region& region,
247- llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
269+ mlir::Region& region, FuncOp funcOp,
270+ llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
271+ llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
272+ SymbolTable& symbolTable) {
248273 llvm::SetVector<Operation*> constantOps;
249274 llvm::SetVector<Operation*> scalarExpansionOps;
250275 region.walk <WalkOrder::PreOrder>([&](Operation* op) {
251- processOp (op, constantOps, scalarExpansionOps,
252- nonConstantNamedComputationOps);
276+ processOp (op, funcOp, constantOps, scalarExpansionOps,
277+ nonConstantNamedComputationOps, nonConstFuncOps, symbolTable );
253278 // Skip walking on the `NamedComputationOp`.
254279 if (isa<NamedComputationOp>(op)) {
255280 return WalkResult::skip ();
@@ -267,6 +292,7 @@ struct ConstantOrScalarSplitterPass
267292
268293 void runOnOperation () final {
269294 ModuleOp moduleOp = getOperation ();
295+ SymbolTable symbolTable (moduleOp);
270296
271297 // We first convert any `stablehlo::ConstantOp` to an `sdy::ConstantOp`, so
272298 // that constants won't be deduped via folding.
@@ -276,15 +302,15 @@ struct ConstantOrScalarSplitterPass
276302
277303 // Then we split constant sub-computations for each non-constant user.
278304 llvm::SmallDenseSet<StringRef> nonConstantNamedComputationOps;
279- // Iterate on a post-order of NamedComputationOp blocks.
280- moduleOp. walk ( [&](NamedComputationOp namedComputationOp ) {
281- walkOnRegion (namedComputationOp. getBody (),
282- nonConstantNamedComputationOps);
283- });
284- // Iterate order does not matter. Funcs do not call each other. The calls
285- // are inlined to NamedComputationOps.
286- moduleOp. walk ([&](FuncOp funcOp) {
287- walkOnRegion (funcOp. getBody (), nonConstantNamedComputationOps );
305+ llvm::SmallDenseSet<StringRef> nonConstFuncOps;
306+ iterateFuncs (moduleOp, [&](FuncOp funcOp ) {
307+ funcOp. walk ([&](NamedComputationOp namedComputationOp) {
308+ walkOnRegion (namedComputationOp. getBody (), funcOp,
309+ nonConstantNamedComputationOps, nonConstFuncOps,
310+ symbolTable);
311+ });
312+ walkOnRegion (funcOp. getBody (), funcOp, nonConstantNamedComputationOps,
313+ nonConstFuncOps, symbolTable );
288314 });
289315 }
290316
0 commit comments