Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 60 additions & 34 deletions shardy/dialect/sdy/transforms/import/constant_or_scalar_splitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand All @@ -48,6 +49,7 @@ namespace sdy {

namespace {

using func::CallOp;
using func::FuncOp;

void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
Expand All @@ -66,10 +68,16 @@ void cloneShardingGroupUsers(OpResult opResult, IRMapping& mapping,
// - A broadcast, reshape or slice op.
// - An elementwise op.
// - A named computation in which all operations are constant preserving.
// - A call to a func in which all operations are constant preserving.
// Assumes the op is not constant or iota.
bool isConstantPreserving(
Operation* op,
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
const llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
const SymbolTable& symbolTable) {
if (CallOp callOp = dyn_cast<CallOp>(op)) {
return !nonConstFuncOps.contains(getOriginalFuncName(callOp, symbolTable));
}
if (auto namedComputationOp = dyn_cast<NamedComputationOp>(op)) {
return !nonConstantNamedComputationOps.contains(
namedComputationOp.getName());
Expand All @@ -93,11 +101,14 @@ bool isConstantPreserving(
// constants, that is, exist in `constantOps`.
bool isConstantExpression(
Operation* op, const llvm::SetVector<Operation*>& constantOps,
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
const llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
const llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
const SymbolTable& symbolTable) {
if (isa<ConstantOp, stablehlo::IotaOp>(op)) {
return true;
}
return isConstantPreserving(op, nonConstantNamedComputationOps) &&
return isConstantPreserving(op, nonConstantNamedComputationOps,
nonConstFuncOps, symbolTable) &&
llvm::all_of(op->getOperands(), [&](Value operand) {
return operand.getDefiningOp() &&
constantOps.contains(operand.getDefiningOp());
Expand All @@ -117,20 +128,26 @@ bool isScalarExpansion(Operation* op) {
// Recursively clones all operands of the given op, that are not already mapped
// in `mapping`, and finally clones the op itself. We do not clone scalars as
// they do not get sharded.
void cloneSubComputation(OpResult opResult, IRMapping& mapping) {
void cloneSubComputation(OpResult opResult, IRMapping& mapping,
SymbolTable& symbolTable) {
if (isScalar(opResult) || mapping.lookupOrNull(opResult)) {
return;
}
Operation* op = opResult.getOwner();
for (Value operand : op->getOperands()) {
if (auto defOpResult = dyn_cast<OpResult>(operand)) {
cloneSubComputation(defOpResult, mapping);
cloneSubComputation(defOpResult, mapping, symbolTable);
}
}

// This will insert the cloned op right before the original op.
OpBuilder builder(op);
builder.clone(*op, mapping);
Operation* clonedOp = builder.clone(*op, mapping);
if (CallOp callOp = dyn_cast<CallOp>(clonedOp)) {
FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
callOp.setCallee(
symbolTable.insert(cloneFuncRecursively(funcOp, symbolTable)));
}
cloneShardingGroupUsers(opResult, mapping, builder);
}

Expand All @@ -139,18 +156,19 @@ void cloneSubComputation(OpResult opResult, IRMapping& mapping) {
// sharded.
//
// Returns the cloned op result.
Value cloneSubComputation(OpResult opResult) {
Value cloneSubComputation(OpResult opResult, SymbolTable& symbolTable) {
if (isScalar(opResult)) {
return opResult;
}
IRMapping mapping;
cloneSubComputation(opResult, mapping);
cloneSubComputation(opResult, mapping, symbolTable);
return mapping.lookup(opResult);
}

void cloneSubComputationOnOperands(
Operation* op, const llvm::SetVector<Operation*>& constantOps,
const llvm::SetVector<Operation*>& scalarExpansionOps) {
const llvm::SetVector<Operation*>& scalarExpansionOps,
SymbolTable& symbolTable) {
for (OpOperand& operand : op->getOpOperands()) {
if (auto defOpResult = dyn_cast<OpResult>(operand.get());
defOpResult && (constantOps.contains(defOpResult.getOwner()) ||
Expand All @@ -160,38 +178,45 @@ void cloneSubComputationOnOperands(
// `defOpResult`, and replace the `operand` with the cloned defining
// op. The cloned constant sub-computation has only one user `op`,
// so that it is isolated from the rest of the computation.
operand.set(cloneSubComputation(defOpResult));
operand.set(cloneSubComputation(defOpResult, symbolTable));
}
}
}

void processOp(Operation* op, llvm::SetVector<Operation*>& constantOps,
void processOp(Operation* op, FuncOp funcOp,
llvm::SetVector<Operation*>& constantOps,
llvm::SetVector<Operation*>& scalarExpansionOps,
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
if (isa<ShardingGroupOp>(op)) {
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
SymbolTable& symbolTable) {
if (isa<FuncOp, ShardingGroupOp>(op)) {
return;
}
if (isConstantExpression(op, constantOps, nonConstantNamedComputationOps)) {
if (isConstantExpression(op, constantOps, nonConstantNamedComputationOps,
nonConstFuncOps, symbolTable)) {
constantOps.insert(op);
return;
}
// NOTE: There are cases where op is a constant expression but may not pass
// the following check, such as constant and iota ops. That is fine because if
// the op is a constant expression it is a stronger condition than being just
// constant preserving and it does not make the parent named computation
// non-const, and at this point, it is guaranteed that the op is not constant
// expression.
if (!isConstantPreserving(op, nonConstantNamedComputationOps) &&
// constant preserving and it does not make the parent named computation or
// the `funcOp` non-const, and at this point, it is guaranteed that the op is
// not a constant expression.
if (!isConstantPreserving(op, nonConstantNamedComputationOps, nonConstFuncOps,
symbolTable) &&
!op->hasTrait<OpTrait::IsTerminator>()) {
if (auto namedCompuationOp = op->getParentOfType<NamedComputationOp>()) {
nonConstantNamedComputationOps.insert(namedCompuationOp.getName());
}
nonConstFuncOps.insert(getOriginalFuncName(funcOp));
}
if (isScalarExpansion(op)) {
scalarExpansionOps.insert(op);
return;
}
cloneSubComputationOnOperands(op, constantOps, scalarExpansionOps);
cloneSubComputationOnOperands(op, constantOps, scalarExpansionOps,
symbolTable);
}

// Converts stablehlo::ConstantOp to sdy::ConstantOp.
Expand Down Expand Up @@ -240,16 +265,16 @@ struct ConstantOrScalarSplitterPass
}
}

// Assumes that the `NamedComputationOp` of the region are already walked, and
// skips walking on them.
void walkOnRegion(
mlir::Region& region,
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps) {
mlir::Region& region, FuncOp funcOp,
llvm::SmallDenseSet<StringRef>& nonConstantNamedComputationOps,
llvm::SmallDenseSet<StringRef>& nonConstFuncOps,
SymbolTable& symbolTable) {
llvm::SetVector<Operation*> constantOps;
llvm::SetVector<Operation*> scalarExpansionOps;
region.walk<WalkOrder::PreOrder>([&](Operation* op) {
processOp(op, constantOps, scalarExpansionOps,
nonConstantNamedComputationOps);
processOp(op, funcOp, constantOps, scalarExpansionOps,
nonConstantNamedComputationOps, nonConstFuncOps, symbolTable);
// Skip walking on the `NamedComputationOp`.
if (isa<NamedComputationOp>(op)) {
return WalkResult::skip();
Expand All @@ -267,6 +292,7 @@ struct ConstantOrScalarSplitterPass

void runOnOperation() final {
ModuleOp moduleOp = getOperation();
SymbolTable symbolTable(moduleOp);

// We first convert any `stablehlo::ConstantOp` to an `sdy::ConstantOp`, so
// that constants won't be deduped via folding.
Expand All @@ -276,15 +302,15 @@ struct ConstantOrScalarSplitterPass

// Then we split constant sub-computations for each non-constant user.
llvm::SmallDenseSet<StringRef> nonConstantNamedComputationOps;
// Iterate on a post-order of NamedComputationOp blocks.
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
walkOnRegion(namedComputationOp.getBody(),
nonConstantNamedComputationOps);
});
// Iterate order does not matter. Funcs do not call each other. The calls
// are inlined to NamedComputationOps.
moduleOp.walk([&](FuncOp funcOp) {
walkOnRegion(funcOp.getBody(), nonConstantNamedComputationOps);
llvm::SmallDenseSet<StringRef> nonConstFuncOps;
iterateFuncs(moduleOp, [&](FuncOp funcOp) {
funcOp.walk([&](NamedComputationOp namedComputationOp) {
walkOnRegion(namedComputationOp.getBody(), funcOp,
nonConstantNamedComputationOps, nonConstFuncOps,
symbolTable);
});
walkOnRegion(funcOp.getBody(), funcOp, nonConstantNamedComputationOps,
nonConstFuncOps, symbolTable);
});
}

Expand Down
28 changes: 22 additions & 6 deletions shardy/dialect/sdy/transforms/import/import_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,19 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
pm.addPass(createLiftInlinedMeshesPass());
pm.addPass(createRemoveSizeOneAxesPass());
pm.addPass(createPropagateShardingFromFuncToCallPass());
pm.addPass(createImportFuncCallsPass());
// Keep SymbolDCEPass after ImportFuncCallsPass.
pm.addPass(createSymbolDCEPass());
if (!options.enableLateInlining) {
pm.addPass(createImportFuncCallsPass());
// Keep SymbolDCEPass after ImportFuncCallsPass.
pm.addPass(createSymbolDCEPass());
}
pm.addPass(createConstantOrScalarSplitterPass());
pm.addPass(createSymbolDCEPass());
pm.addPass(createManualAxesCleanupPass());
if (options.enableLateInlining) {
pm.addPass(createImportFuncCallsPass());
// Keep SymbolDCEPass after ImportFuncCallsPass.
pm.addPass(createSymbolDCEPass());
}

// We dump the module before propagation at this point, since the import
// passes before are cleanup passes that make the module more readable, and
Expand All @@ -61,14 +68,23 @@ void addImportPipeline(OpPassManager& pm, const PropagationOptions& options) {
addImportPipeline(pm, dumpIndex, options);
}

struct ImportPipelineOptions
: public PassPipelineOptions<ImportPipelineOptions> {
Option<bool> enableLateInlining{*this, "enable-late-inlining",
llvm::cl::desc("Whether to late inline."),
llvm::cl::init(true)};
};

void registerImportPipeline() {
PassPipelineRegistration<>(
PassPipelineRegistration<ImportPipelineOptions>(
"sdy-import-pipeline",
"Run a sequence of import passes needed as a pre-processing step for "
"Shardy propagation",
[](OpPassManager& pm) {
[](OpPassManager& pm, const ImportPipelineOptions& options) {
int dumpIndex = 0;
addImportPipeline(pm, dumpIndex, PropagationOptions());
addImportPipeline(pm, dumpIndex,
PropagationOptions{.enableLateInlining =
options.enableLateInlining});
});
}

Expand Down
Loading
Loading