Skip to content

Commit b1dd5aa

Browse files
ekayaaslan authored and copybara-github committed
Drop inlining on shardy.
It adds func data flow edges unconditionally. It propagates shardings on FuncDataFlowEdges from their sources to their users. If the operand of a FuncDataFlowEdge is: - BlockArgument: The operand is a func argument. The source is the corresponding operand of the func's caller. - OpResult: The operand is a call result. The source is the terminator of the called func. PiperOrigin-RevId: 900761579
1 parent 9e4d9b8 commit b1dd5aa

11 files changed

Lines changed: 390 additions & 83 deletions

shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ using func::FuncOp;
3838

3939
// Adds func input and output data flow edges. Adds func input data flow edges
4040
// only for non-main funcs.
41-
void addFuncDataFlowEdges(ModuleOp moduleOp, const SymbolTable& symbolTable,
42-
IRRewriter& rewriter) {
43-
FuncOp mainFuncOp = getMainFuncOrDie(moduleOp, symbolTable);
41+
void addFuncDataFlowEdgeOps(ModuleOp moduleOp, const SymbolTable& symbolTable,
42+
IRRewriter& rewriter) {
43+
FuncOp mainFuncOp =
44+
getMainFuncOrDie(moduleOp, symbolTable, /*useSingleFunc=*/true);
4445
moduleOp.walk([&](FuncOp funcOp) {
4546
if (funcOp == mainFuncOp) {
4647
return;
@@ -84,8 +85,8 @@ struct AddDataFlowEdgesPass
8485
addDataFlowEdges(op.getBlockArgumentEdgeOwners(), rewriter);
8586
addDataFlowEdges(op.getOpResultEdgeOwners(), rewriter);
8687
});
87-
if (enableNativeNonFlatSupport) {
88-
addFuncDataFlowEdges(moduleOp, symbolTable, rewriter);
88+
if (addFuncDataFlowEdges) {
89+
addFuncDataFlowEdgeOps(moduleOp, symbolTable, rewriter);
8990
}
9091
}
9192
};

shardy/dialect/sdy/transforms/import/import_pipeline.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
4343
options.dumpDirectory, "before_propagation", dumpIndex++));
4444

4545
pm.addPass(createAddDataFlowEdgesPass(
46-
AddDataFlowEdgesPassOptions{options.enableNativeNonFlatSupport}));
46+
AddDataFlowEdgesPassOptions{/*addFuncDataFlowEdges=*/true}));
4747
pm.addPass(
4848
createApplyShardingConstraintsPass(ApplyShardingConstraintsPassOptions{
4949
options.debugShardingOrigins, options.debugPropagationEdgeSharding}));

shardy/dialect/sdy/transforms/import/passes.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,9 @@ def AddDataFlowEdgesPass : Pass<"sdy-add-data-flow-edges", "ModuleOp"> {
9292
let dependentDialects = ["mlir::sdy::SdyDialect"];
9393

9494
let options = [
95-
Option<"enableNativeNonFlatSupport", "enable-native-non-flat-support", "bool",
95+
Option<"addFuncDataFlowEdges", "add-func-data-flow-edges", "bool",
9696
/*default=*/"false",
97-
"Whether to propagate shardings directly on a non-flat graph without "
98-
"flattening it. The default is false, meaning it will flatten the "
99-
"graph and then propagate.">
97+
"Whether to add func data flow edges.">
10098
];
10199
}
102100

shardy/dialect/sdy/transforms/import/test/add_data_flow_edges_enable_native_non_flat_support.mlir renamed to shardy/dialect/sdy/transforms/import/test/add_data_flow_edges_add_func_data_flow_edges_true.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: sdy_opt %s -split-input-file -sdy-add-data-flow-edges='enable-native-non-flat-support=true' | FileCheck %s
1+
// RUN: sdy_opt %s -split-input-file -sdy-add-data-flow-edges='add-func-data-flow-edges=true' | FileCheck %s
22

33
// CHECK-LABEL: @bar(%arg0: tensor<8xf32>)
44
func.func @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {

shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
1717

1818
// -----
1919

20+
2021
// CHECK-LABEL: func @main
2122
func.func @main(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) {
2223
// CHECK-DAG: sdy.sharding_group %arg0 group_id=0 : tensor<8x8xf32>
@@ -30,10 +31,11 @@ func.func @main(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) {
3031

3132
// -----
3233

34+
3335
sdy.mesh @mesh = <["c"=2, "a"=2, "b"=2]>
3436

35-
// CHECK-LABEL: @add_manual_axes_to_replicated
36-
func.func @add_manual_axes_to_replicated(%arg0: tensor<8xf32>) -> tensor<8xf32> {
37+
// CHECK-LABEL: func @main
38+
func.func @main(%arg0: tensor<8xf32>) -> tensor<8xf32> {
3739
// CHECK-NEXT: sdy.manual_computation(%arg0)
3840
// CHECK-SAME{LITERAL}: in_shardings=[<@mesh, [{"c", ?}], replicated={"a"}>]
3941
// CHECK-SAME{LITERAL}: out_shardings=[<@mesh, [{"c", ?}], replicated={"a"}>]
@@ -46,14 +48,15 @@ func.func @add_manual_axes_to_replicated(%arg0: tensor<8xf32>) -> tensor<8xf32>
4648

4749
// -----
4850

51+
4952
sdy.mesh @mesh = <["c"=2, "a"=2, "b"=2]>
5053

5154
// Due to the in_sharding being fully closed, the in_sharding is added to the
5255
// func arg but with the manual axis added as replicated.
53-
// CHECK-LABEL: @add_manual_axes_to_replicated_applied_constraint
56+
// CHECK-LABEL: func @main
5457
// CHECK-SAME %arg0: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}], replicated={"a"}>}
5558
// CHECK-SAME -> tensor<16x16xf32> {
56-
func.func @add_manual_axes_to_replicated_applied_constraint(%arg0: tensor<8xf32>) -> tensor<8xf32> {
59+
func.func @main(%arg0: tensor<8xf32>) -> tensor<8xf32> {
5760
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"c"}]>] out_shardings=[<@mesh, [{"c"}]>] manual_axes={"c", "a"} (%arg1: tensor<4xf32>) {
5861
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
5962
sdy.return %1 : tensor<4xf32>
@@ -63,27 +66,31 @@ func.func @add_manual_axes_to_replicated_applied_constraint(%arg0: tensor<8xf32>
6366

6467
// -----
6568

69+
6670
sdy.mesh @mesh = <["a"=2]>
6771

6872
// This test verifies that the manual axes are cleaned up before adding data
6973
// flow edges.
70-
func.func @manual_axes_cleanup_before_adding_data_flow_edges(%arg0: tensor<8xf32>) -> tensor<8xf32> {
74+
func.func @main(%arg0: tensor<8xf32>) -> tensor<8xf32> {
7175
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}]>] out_shardings=[<@mesh, [{?}]>] manual_axes={"a"} (%arg1: tensor<8xf32>) {
7276
%1 = stablehlo.add %arg1, %arg1 : tensor<8xf32>
7377
sdy.return %1 : tensor<8xf32>
7478
} : (tensor<8xf32>) -> tensor<8xf32>
75-
// CHECK: sdy.data_flow_edge %0 sharding=<@mesh, [{?}], replicated={"a"}> : tensor<8xf32>
79+
// CHECK: sdy.data_flow_edge %0 sharding=<@mesh, [{"c"}], replicated={"a"}> : tensor<8xf32>
7680
return %0 : tensor<8xf32>
7781
}
7882

7983
// -----
8084

85+
8186
sdy.mesh @mesh = <["a"=2]>
8287

83-
// CHECK-LABEL: func @single_call
84-
func.func @single_call(%arg0: tensor<8xf32>) -> tensor<8xf32> {
85-
// CHECK-NEXT: %0 = call @foo(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
86-
// CHECK-NEXT: return %0 : tensor<8xf32>
88+
// test: single_call
89+
// CHECK-LABEL: func @main
90+
func.func @main(%arg0: tensor<8xf32>) -> tensor<8xf32> {
91+
// CHECK: %[[CALL:.*]] = call @foo(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
92+
// CHECK: %[[EDGE:.*]] = sdy.func_data_flow_edge %[[CALL]] : tensor<8xf32>
93+
// CHECK: return %[[EDGE]] : tensor<8xf32>
8794
%0 = call @foo(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
8895
return %0 : tensor<8xf32>
8996
}

shardy/dialect/sdy/transforms/propagation/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ cc_library(
5757
":sharding_projection",
5858
":utils",
5959
"//shardy/common:file_utils",
60+
"//shardy/common:logging",
6061
"//shardy/dialect/sdy/ir:dialect",
6162
"//shardy/dialect/sdy/transforms/common:op_properties",
6263
"//shardy/dialect/sdy/transforms/common:propagation_options",

0 commit comments

Comments
 (0)