Skip to content

Commit a9bc412

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy inliner down out of import pipeline for late inlining.
PiperOrigin-RevId: 900092715
1 parent 2bb09b4 commit a9bc412

13 files changed

+895
-118
lines changed

shardy/dialect/sdy/transforms/import/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ cc_library(
5858
"//shardy/dialect/sdy/transforms/common:sharding_walker",
5959
"//shardy/dialect/sdy/transforms/propagation:op_sharding_rule_registry",
6060
"//shardy/dialect/sdy/transforms/propagation:sharding_projection",
61+
"//shardy/dialect/sdy/transforms/propagation:utils",
6162
"//shardy/dialect/sdy/transforms/propagation/debugging:source_sharding",
6263
"@llvm-project//llvm:Support",
6364
"@llvm-project//mlir:Analysis",

shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include "shardy/dialect/sdy/ir/dialect.h"
2222
#include "shardy/dialect/sdy/ir/utils.h"
2323
#include "shardy/dialect/sdy/transforms/import/passes.h" // IWYU pragma: keep
24+
#include "shardy/dialect/sdy/transforms/propagation/utils.h"
2425

2526
namespace mlir {
2627
namespace sdy {
@@ -34,22 +35,6 @@ struct AddDataFlowEdgesPass
3435
: public impl::AddDataFlowEdgesPassBase<AddDataFlowEdgesPass> {
3536
using AddDataFlowEdgesPassBase::AddDataFlowEdgesPassBase;
3637

37-
void addDataFlowEdges(ValueRange edgeOwners, IRRewriter& rewriter) {
38-
// We are iterating the owners in a reversed order because we set the
39-
// insertion point after each value and we would like to keep the data flow
40-
// edges for the arguments/results in the same order as they appear.
41-
for (Value edgeOwner : llvm::reverse(edgeOwners)) {
42-
rewriter.setInsertionPointAfterValue(edgeOwner);
43-
if (!isStaticShapedType(edgeOwner.getType())) {
44-
// Skip non-static-shaped tensors, e.g., tokens.
45-
continue;
46-
}
47-
auto dataFlowEdge = DataFlowEdgeOp::create(
48-
rewriter, edgeOwner.getLoc(), edgeOwner, getSharding(edgeOwner));
49-
rewriter.replaceAllUsesExcept(edgeOwner, dataFlowEdge, dataFlowEdge);
50-
}
51-
}
52-
5338
void runOnOperation() final {
5439
func::FuncOp funcOp = getOperation();
5540
IRRewriter rewriter(funcOp);

shardy/dialect/sdy/transforms/import/import_func_calls.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "shardy/dialect/sdy/ir/dialect.h"
3636
#include "shardy/dialect/sdy/ir/utils.h"
3737
#include "shardy/dialect/sdy/transforms/import/passes.h" // IWYU pragma: keep
38+
#include "shardy/dialect/sdy/transforms/propagation/utils.h"
3839

3940
namespace mlir {
4041
namespace sdy {
@@ -117,6 +118,18 @@ struct ImportFuncCallsPass
117118
for (auto [calleeName, _] : calleeNameToMovedRegion) {
118119
symbolTable.erase(symbolTable.lookup(calleeName));
119120
}
121+
122+
// Required for cases that AddDataFlowEdges runs before this pass.
123+
// TODO(enver): Drop after the late inlining drops ImportFuncCalls pass
124+
// altogether as long as early inlining is before AddDataFlowEdges pass.
125+
if (addDataFlowEdgesOnNamedComputations) {
126+
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
127+
// Add the data flow edges for result owners and block argument owners.
128+
addDataFlowEdges(namedComputationOp.getBlockArgumentEdgeOwners(),
129+
rewriter);
130+
addDataFlowEdges(namedComputationOp.getOpResultEdgeOwners(), rewriter);
131+
});
132+
}
120133
}
121134
};
122135

shardy/dialect/sdy/transforms/import/import_pipeline.cc

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include "llvm/Support/CommandLine.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
1718
#include "mlir/Pass/PassManager.h"
19+
#include "mlir/Pass/PassOptions.h"
1820
#include "mlir/Pass/PassRegistry.h"
1921
#include "mlir/Support/LLVM.h"
2022
#include "mlir/Transforms/Passes.h"
@@ -31,9 +33,12 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
3133
pm.addPass(createLiftInlinedMeshesPass());
3234
pm.addPass(createRemoveSizeOneAxesPass());
3335
pm.addPass(createPropagateShardingFromFuncToCallPass());
34-
pm.addPass(createImportFuncCallsPass());
35-
// Keep SymbolDCEPass after ImportFuncCallsPass.
36-
pm.addPass(createSymbolDCEPass());
36+
if (!options.enableLateInlining) {
37+
pm.addPass(createImportFuncCallsPass(ImportFuncCallsPassOptions{
38+
/*addDataFlowEdgesOnNamedComputations=*/false}));
39+
// Keep SymbolDCEPass after ImportFuncCallsPass.
40+
pm.addPass(createSymbolDCEPass());
41+
}
3742
pm.addPass(createConstantOrScalarSplitterPass());
3843
pm.addPass(createSymbolDCEPass());
3944
pm.addPass(createManualAxesCleanupPass());
@@ -61,14 +66,23 @@ void addImportPipeline(OpPassManager& pm, const PropagationOptions& options) {
6166
addImportPipeline(pm, dumpIndex, options);
6267
}
6368

69+
struct ImportPipelineOptions
70+
: public PassPipelineOptions<ImportPipelineOptions> {
71+
Option<bool> enableLateInlining{*this, "enable-late-inlining",
72+
llvm::cl::desc("Whether to late inline."),
73+
llvm::cl::init(true)};
74+
};
75+
6476
void registerImportPipeline() {
65-
PassPipelineRegistration<>(
77+
PassPipelineRegistration<ImportPipelineOptions>(
6678
"sdy-import-pipeline",
6779
"Run a sequence of import passes needed as a pre-processing step for "
6880
"Shardy propagation",
69-
[](OpPassManager& pm) {
81+
[](OpPassManager& pm, const ImportPipelineOptions& options) {
7082
int dumpIndex = 0;
71-
addImportPipeline(pm, dumpIndex, PropagationOptions());
83+
addImportPipeline(pm, dumpIndex,
84+
PropagationOptions{.enableLateInlining =
85+
options.enableLateInlining});
7286
});
7387
}
7488

shardy/dialect/sdy/transforms/import/passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ def ImportFuncCallsPass : Pass<"sdy-import-func-calls", "ModuleOp"> {
2626
function body for each call op and emit a warning.
2727
}];
2828
let dependentDialects = ["mlir::sdy::SdyDialect"];
29+
let options = [
30+
Option<"addDataFlowEdgesOnNamedComputations", "add-data-flow-edges-on-named-computations", "bool",
31+
/*default=*/"true",
32+
"Whether to add data flow edges on named computations.">
33+
];
2934
}
3035

3136
def PropagateShardingFromFuncToCallPass : Pass<"sdy-propagate-sharding-from-func-to-call", "ModuleOp"> {

0 commit comments

Comments
 (0)