@@ -42,24 +42,19 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
4242 passOptions.enableFullVersion = options.enableInsertExplicitCollectives ;
4343 passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls ;
4444 pm.addNestedPass <func::FuncOp>(createInsertExplicitReshardsPass (passOptions));
45+ pm.addPass (createUnflattenCallGraphPass (
46+ UnflattenCallGraphPassOptions{options.dedupFunctionsFully }));
47+ // Keep a SymbolDCE after UnflattenCallGraph.
48+ pm.addPass (createSymbolDCEPass ());
4549 if (options.enableInsertExplicitCollectives ) {
4650 pm.addPass (mlir::sdy::createSaveModuleOpPass (
4751 options.dumpDirectory , " after_explicit_reshards" , dumpIndex++));
4852 addCanonicalizerPass (pm, kReshardLabel );
49- pm.addPass (createUnflattenCallGraphPass (
50- UnflattenCallGraphPassOptions{options.dedupFunctionsFully }));
51- // Keep a SymbolDCE after UnflattenCallGraph.
52- pm.addPass (createSymbolDCEPass ());
5353 pm.addNestedPass <func::FuncOp>(createReshardToCollectivesPass ());
5454 // NOTE: ReshardToCollectives pass above generates all-slice collectives,
5555 // which during the canonicalizer below may be converted to reduce scatters
5656 // by potentially fusing with preceeding all-reduces, which are inserted
5757 // during InsertExplicitReshards pass.
58- } else {
59- pm.addPass (createUnflattenCallGraphPass (
60- UnflattenCallGraphPassOptions{options.dedupFunctionsFully }));
61- // Keep a SymbolDCE after UnflattenCallGraph.
62- pm.addPass (createSymbolDCEPass ());
6358 }
6459
6560 addCanonicalizerPass (pm, kCollectiveLabel );
0 commit comments