Skip to content

Commit 883fc34

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy unflatenner up past Canonicalizer on reshards.
PiperOrigin-RevId: 908449560
1 parent 03f2b7b commit 883fc34

1 file changed

Lines changed: 4 additions & 9 deletions

File tree

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,19 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
4242
passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
4343
passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
4444
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
45+
pm.addPass(createUnflattenCallGraphPass(
46+
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
47+
// Keep a SymbolDCE after UnflattenCallGraph.
48+
pm.addPass(createSymbolDCEPass());
4549
if (options.enableInsertExplicitCollectives) {
4650
pm.addPass(mlir::sdy::createSaveModuleOpPass(
4751
options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
4852
addCanonicalizerPass(pm, kReshardLabel);
49-
pm.addPass(createUnflattenCallGraphPass(
50-
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
51-
// Keep a SymbolDCE after UnflattenCallGraph.
52-
pm.addPass(createSymbolDCEPass());
5353
pm.addNestedPass<func::FuncOp>(createReshardToCollectivesPass());
5454
// NOTE: ReshardToCollectives pass above generates all-slice collectives,
5555
// which during the canonicalizer below may be converted to reduce scatters
5656
// by potentially fusing with preceeding all-reduces, which are inserted
5757
// during InsertExplicitReshards pass.
58-
} else {
59-
pm.addPass(createUnflattenCallGraphPass(
60-
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
61-
// Keep a SymbolDCE after UnflattenCallGraph.
62-
pm.addPass(createSymbolDCEPass());
6358
}
6459

6560
addCanonicalizerPass(pm, kCollectiveLabel);

0 commit comments

Comments
 (0)