Skip to content

Commit 16eab08

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy flattener down past AddDataFlowEdges pass.
AddDataFlowEdges adds func data flow edges for func arguments and call results. Having a non-flat call graph does not impact the behavior. PiperOrigin-RevId: 908139938
1 parent 03f2b7b commit 16eab08

2 files changed

Lines changed: 207 additions & 71 deletions

File tree

shardy/dialect/sdy/transforms/import/import_pipeline.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
3131
pm.addPass(createLiftInlinedMeshesPass());
3232
pm.addPass(createRemoveSizeOneAxesPass());
3333
pm.addPass(createPropagateShardingFromFuncToCallPass());
34-
// Keep SymbolDCE after FlattenCallGraph.
35-
pm.addPass(createSymbolDCEPass());
3634
pm.addPass(createConstantOrScalarSplitterPass());
3735
pm.addPass(createSymbolDCEPass());
3836
pm.addPass(createManualAxesCleanupPass());
@@ -44,11 +42,11 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
4442
pm.addPass(mlir::sdy::createSaveModuleOpPass(
4543
options.dumpDirectory, "before_propagation", dumpIndex++));
4644

45+
pm.addPass(createAddDataFlowEdgesPass(
46+
AddDataFlowEdgesPassOptions{options.enableNativeNonFlatSupport}));
4747
pm.addPass(createFlattenCallGraphPass());
4848
// Keep SymbolDCE after FlattenCallGraph.
4949
pm.addPass(createSymbolDCEPass());
50-
pm.addPass(createAddDataFlowEdgesPass(
51-
AddDataFlowEdgesPassOptions{options.enableNativeNonFlatSupport}));
5250
pm.addPass(
5351
createApplyShardingConstraintsPass(ApplyShardingConstraintsPassOptions{
5452
options.debugShardingOrigins, options.debugPropagationEdgeSharding}));

0 commit comments

Comments
 (0)