Skip to content

Commit 9f4effd

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy unflatenner up past RemoveAllGatherReduceScatterForCMV1 pass.
RemoveAllGatherReduceScatterForCMV1 removes all-gather and reduce-scatter ops on certain patterns. It does not interact with the call graph. PiperOrigin-RevId: 908228518
1 parent 85854db commit 9f4effd

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
5353
// during InsertExplicitReshards pass.
5454
}
5555
addCanonicalizerPass(pm, kCollectiveLabel);
56+
pm.addPass(createUnflattenCallGraphPass(
57+
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
58+
// Keep a SymbolDCE after UnflattenCallGraph.
59+
pm.addPass(createSymbolDCEPass());
60+
5661
if (options.enableInsertExplicitCollectives &&
5762
options.removeAllGatherReduceScatterForCMV1) {
5863
pm.addNestedPass<func::FuncOp>(
@@ -95,11 +100,12 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
95100
// reshards/collectives.
96101
if (!options.avoidExportForPartitioning) {
97102
runShardyPartitioner(pm, dumpIndex, options);
103+
} else {
104+
pm.addPass(createUnflattenCallGraphPass(
105+
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
106+
// Keep a SymbolDCE after UnflattenCallGraph.
107+
pm.addPass(createSymbolDCEPass());
98108
}
99-
pm.addPass(createUnflattenCallGraphPass(
100-
UnflattenCallGraphPassOptions{options.dedupFunctionsFully}));
101-
// Keep a SymbolDCE after UnflattenCallGraph.
102-
pm.addPass(createSymbolDCEPass());
103109
if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
104110
pm.addPass(createRemovePropagationDebugInfoPass());
105111
}

0 commit comments

Comments
 (0)