diff --git a/shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc b/shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc index 69111d254..95e61c2b7 100644 --- a/shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc +++ b/shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc @@ -21,9 +21,11 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LLVM.h" @@ -124,6 +126,39 @@ Value cloneSubComputation(OpResult opResult) { return mapping.lookup(opResult); } +void cloneSubComputationOnOperands( + Operation* op, const llvm::SetVector& constantOps, + const llvm::SetVector& scalarExpansionOps) { + for (OpOperand& operand : op->getOpOperands()) { + if (auto defOpResult = dyn_cast(operand.get()); + defOpResult && (constantOps.contains(defOpResult.getOwner()) || + scalarExpansionOps.contains(defOpResult.getOwner()))) { + // `op` is not a constant expression, while its `operand` is. We + // recursively clone the sub-computation whose root is + // `defOpResult`, and replace the `operand` with the cloned defining + // op. The cloned constant sub-computation has only one user `op`, + // so that it is isolated from the rest of the computation. + operand.set(cloneSubComputation(defOpResult)); + } + } +} + +void processOp(Operation* op, llvm::SetVector& constantOps, + llvm::SetVector& scalarExpansionOps) { + if (isa(op)) { + return; + } + if (isConstantExpression(op, constantOps)) { + constantOps.insert(op); + return; + } + if (isScalarExpansion(op)) { + scalarExpansionOps.insert(op); + return; + } + cloneSubComputationOnOperands(op, constantOps, scalarExpansionOps); +} + // Converts stablehlo::ConstantOp to sdy::ConstantOp. class ConstantPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -167,32 +202,8 @@ struct ConstantOrScalarSplitterPass // Then we split constant sub-computations for each non-constant user. llvm::SetVector constantOps, scalarExpansionOps; - funcOp.walk([&](Operation* op) { - if (isa(op)) { - return; - } - if (isConstantExpression(op, constantOps)) { - constantOps.insert(op); - return; - } - if (isScalarExpansion(op)) { - scalarExpansionOps.insert(op); - return; - } - for (OpOperand& operand : op->getOpOperands()) { - if (auto defOpResult = dyn_cast(operand.get()); - defOpResult && - (constantOps.contains(defOpResult.getOwner()) || - scalarExpansionOps.contains(defOpResult.getOwner()))) { - // `op` is not a constant expression, while its `operand` is. We - // recursively clone the sub-computation whose root is - // `defOpResult`, and replace the `operand` with the cloned defining - // op. The cloned constant sub-computation has only one user `op`, - // so that it is isolated from the rest of the computation. - operand.set(cloneSubComputation(defOpResult)); - } - } - }); + funcOp.walk( + [&](Operation* op) { processOp(op, constantOps, scalarExpansionOps); }); // Since for every op in `constantOps` that has a use that isn't in // `constantOps`, we replaced the use with a clone of the entire