Skip to content

Commit f76a02e

Browse files
ekayaaslan authored and copybara-github committed
Push shardy flattener down past ConstantOrScalarSplitter pass.
ConstantOrScalarSplitter splits constant subcomputations including constant-preserving func calls. It is not sufficient to just split the calls and keep calling the same func, because if two copies of the call get different shardings, it implies a sharding conflict for at least one of them between the call and the func, and it implies a reshard to resolve func/call sharding conflicts. That would defy the purpose of constant splitting. Hence, keep flattening the funcs on the path of constant subcomputations for the constant splitter. That is, given that it gets a non-flat graph, it selectively flattens. PiperOrigin-RevId: 908358825
1 parent f507a17 commit f76a02e

2 files changed

Lines changed: 29 additions & 1 deletion

File tree

shardy/dialect/sdy/transforms/import/import_pipeline.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
3131
pm.addPass(createLiftInlinedMeshesPass());
3232
pm.addPass(createRemoveSizeOneAxesPass());
3333
pm.addPass(createPropagateShardingFromFuncToCallPass());
34-
pm.addPass(createFlattenCallGraphPass());
3534
// Keep SymbolDCE after FlattenCallGraph.
3635
pm.addPass(createSymbolDCEPass());
3736
pm.addPass(createConstantOrScalarSplitterPass());
37+
pm.addPass(createFlattenCallGraphPass());
3838
pm.addPass(createSymbolDCEPass());
3939
pm.addPass(createManualAxesCleanupPass());
4040

shardy/dialect/sdy/transforms/import/test/constant_or_scalar_splitter.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,3 +1752,31 @@ func.func @split_constants_different_sharding(
17521752
%2 = stablehlo.add %1, %arg0 : tensor<8x8xf32>
17531753
return %0, %2 : tensor<8x16xf32>, tensor<8x8xf32>
17541754
}
1755+
1756+
// -----
1757+
1758+
// CHECK-LABEL: func @simple_non_flat
1759+
func.func @simple_non_flat() -> (tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>) {
1760+
// CHECK-NEXT: %0 = call @foo_0()
1761+
// CHECK-NEXT: %1 = call @foo_1()
1762+
// CHECK-NEXT: %2 = stablehlo.abs %0
1763+
// CHECK-NEXT: %3 = stablehlo.abs %1
1764+
// CHECK-NEXT: %4 = call @foo_2()
1765+
// CHECK-NEXT: return %2, %3, %4
1766+
%0 = call @foo() : () -> (tensor<8x16xf32>)
1767+
%1 = stablehlo.abs %0 : tensor<8x16xf32>
1768+
%2 = stablehlo.abs %0 : tensor<8x16xf32>
1769+
%3 = call @foo() : () -> (tensor<8x16xf32>)
1770+
return %1, %2, %3 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>
1771+
}
1772+
1773+
// CHECK-LABEL: func.func private @foo
1774+
func.func private @foo() -> tensor<8x16xf32> {
1775+
%0 = stablehlo.constant dense<1.000000e+00> : tensor<8x16xf32>
1776+
%1 = stablehlo.negate %0 : tensor<8x16xf32>
1777+
return %1 : tensor<8x16xf32>
1778+
}
1779+
1780+
// CHECK-LABEL: func.func private @foo_0() -> tensor<8x16xf32> attributes {sdy.original_func_name = "foo"} {
1781+
// CHECK-LABEL: func.func private @foo_1() -> tensor<8x16xf32> attributes {sdy.original_func_name = "foo"} {
1782+
// CHECK-LABEL: func.func private @foo_2() -> tensor<8x16xf32> attributes {sdy.original_func_name = "foo"} {

0 commit comments

Comments
 (0)