
Commit 7cd9098

ekayaaslan authored and copybara-github committed
Sink func data flow edges for funcs and calls.
Drop the enableNativeNonFlatSupport flag from the sinking pass. This does not affect prod: prod does not contain func data flow edges yet, since they are gated behind the enableNativeNonFlatSupport flag.

PiperOrigin-RevId: 902646126
1 parent 1534841 commit 7cd9098
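
In effect, the sinking pass now folds `sdy.func_data_flow_edge` ops away, moving any sharding they carry onto the edge's owner (a function argument or a call op) and replacing the op with its input. A minimal before/after sketch of the argument case, distilled from the test cases added in this commit (mesh `@mesh = <["a"=2]>` as in those tests):

  // Before sinking: the sharding lives on the edge op.
  func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
    %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
    %1 = stablehlo.negate %0 : tensor<8xf32>
    return %1 : tensor<8xf32>
  }

  // After sinking: the edge is gone and the sharding lives on the argument.
  func.func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) -> tensor<8xf32> {
    %0 = stablehlo.negate %arg0 : tensor<8xf32>
    return %0 : tensor<8xf32>
  }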

7 files changed: 218 additions & 133 deletions

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 4 additions & 5 deletions
@@ -76,11 +76,10 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
     pm.addPass(createRemoveShardingGroupsPass());
     pm.addNestedPass<func::FuncOp>(createShardingConstraintToReshardPass());
   }
-  pm.addNestedPass<
-      func::FuncOp>(createSinkDataFlowEdgesPass(SinkDataFlowEdgesPassOptions{
-      /*sinkDebugShardingOrigins=*/options.dumpShardingOrigins,
-      /*sinkDebugPropagationEdgeSharding=*/options.dumpPropagationEdges,
-      /*sinkEnableNativeNonFlatSupport=*/options.enableNativeNonFlatSupport}));
+  pm.addNestedPass<func::FuncOp>(
+      createSinkDataFlowEdgesPass(SinkDataFlowEdgesPassOptions{
+          /*sinkDebugShardingOrigins=*/options.dumpShardingOrigins,
+          /*sinkDebugPropagationEdgeSharding=*/options.dumpPropagationEdges}));
   if (options.updateNonDivisibleInputOutputShardings) {
     pm.addPass(createUpdateNonDivisibleInputOutputShardingsPass());
     pm.addPass(createRemoveSubAxesInInputOutputShardingsPass());

shardy/dialect/sdy/transforms/export/passes.h

Lines changed: 0 additions & 7 deletions
@@ -86,13 +86,6 @@ struct ExportOptions : public PassPipelineOptions<ExportOptions> {
       *this, "update-non-divisible-input-output-shardings",
       llvm::cl::desc("Update axes with non-divisible input/output shardings."),
       llvm::cl::init(true)};
-
-  Option<bool> enableNativeNonFlatSupport{
-      *this, "enable-native-non-flat-support",
-      llvm::cl::desc("Whether to propagate shardings directly on a non-flat "
-                     "graph without flattening it. The default is false, "
-                     "meaning it will flatten the graph and then propagate."),
-      llvm::cl::init(false)};
 };

 // Adds a sequence of export passes needed as a post-processing step for SDY

shardy/dialect/sdy/transforms/export/passes.td

Lines changed: 0 additions & 5 deletions
@@ -40,11 +40,6 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
            "Whether to sink the debug propagation edge sharding info. See "
            "`debug-propagation-edge-sharding` option in propagation for more "
            "info.">,
-    Option<"enableNativeNonFlatSupport", "enable-native-non-flat-support", "bool",
-           /*default=*/"false",
-           "Whether to propagate shardings directly on a non-flat graph without "
-           "flattening it. The default is false, meaning it will flatten the "
-           "graph and then propagate.">
   ];
 }

shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc

Lines changed: 14 additions & 12 deletions
@@ -64,21 +64,23 @@ struct SinkDataFlowEdgesPass
       if (isa<DataFlowEdgeOp>(op)) {
        DataFlowEdgeOp dataFlowEdgeOp = cast<DataFlowEdgeOp>(op);
        Value input = dataFlowEdgeOp.getInput();
-        // TODO(enver): Drop enableNativeNonFlatSupport check and assume func
-        // arguments do not have data flow edges in the first place.
-        if (enableNativeNonFlatSupport) {
-          if (func::FuncOp funcOp =
-                  dyn_cast<func::FuncOp>(getOwningOp(input))) {
-            if (TensorShardingAttr sharding = dataFlowEdgeOp.getShardingAttr();
-                sharding) {
-              funcOp.setArgAttr(cast<BlockArgument>(input).getArgNumber(),
-                                kShardingAttr, sharding);
-            }
-          }
-        }
        rewriter.replaceOp(dataFlowEdgeOp, input);
        return WalkResult::skip();
      }
+      if (isa<FuncDataFlowEdgeOp>(op)) {
+        FuncDataFlowEdgeOp funcEdgeOp = cast<FuncDataFlowEdgeOp>(op);
+        Value operand = funcEdgeOp.getOperand();
+        Value result = funcEdgeOp.getResult();
+        TensorShardingAttr operandSharding = getSharding(operand);
+        if (TensorShardingAttr sharding = getSharding(result)) {
+          setSharding(operand, sharding);
+        } else if (operandSharding) {
+          setSharding(operand,
+                      TensorShardingAttr::getFullyOpenLike(operandSharding));
+        }
+        rewriter.replaceOp(funcEdgeOp, operand);
+        return WalkResult::skip();
+      }
      auto shardableDataFlowOp = dyn_cast<ShardableDataFlowOpInterface>(op);
      if (!shardableDataFlowOp) {
        return WalkResult::advance();
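
For call results, the new `FuncDataFlowEdgeOp` branch prefers the sharding on the edge itself; if the edge has none but its operand (the call result) does, the operand's sharding is relaxed to a fully open one via `TensorShardingAttr::getFullyOpenLike`. A sketch of that fallback, matching the `func_data_flow_edge_has_sharding_call_does_not` test below (with `@mesh = <["a"=2]>`):

  // Before: the call carries a sharding, the edge does not.
  %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<8xf32>) -> (tensor<8xf32>)
  %2 = sdy.func_data_flow_edge %1 : tensor<8xf32>

  // After: the edge is folded away and the call's sharding becomes fully open.
  %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}]>]>} : (tensor<8xf32>) -> (tensor<8xf32>)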

shardy/dialect/sdy/transforms/export/test/sink_data_flow_edges.mlir

Lines changed: 200 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: sdy_opt %s -sdy-sink-data-flow-edges | FileCheck %s
+// RUN: sdy_opt %s -split-input-file -sdy-sink-data-flow-edges | FileCheck %s

 sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>
 sdy.mesh @other_mesh = <["c"=4]>
@@ -249,3 +249,202 @@ func.func @manual_computation_origin_debug_info(%arg0: tensor<32x32x32xf32>) ->
   %2 = sdy.data_flow_edge %1 sharding=<@mesh, [{"a", ?}, {"b", ?}, {?}]> {sdy.origin_sharding = {a = "mc_0_input: 0", b = "mc_0_output: 0"}} : tensor<32x32x32xf32>
   return %2 : tensor<32x32x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32>)
+func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
+  // CHECK-NEXT: return %[[NEGATE]]
+  %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32>
+  %1 = stablehlo.negate %0 : tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @simple_call_graph_on_func_with_single_argument(%arg0: tensor<8xf32>)
+func.func @simple_call_graph_on_func_with_single_argument(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]])
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS1]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = call @bar(%0) : (tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  return %3 : tensor<8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @bar(%arg0: tensor<8xf32>)
+func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
+  // CHECK-NEXT: return %[[NEGATE]]
+  %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32>
+  %1 = stablehlo.negate %0 : tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: @multiple_calls_on_same_func(%arg0: tensor<8xf32>)
+func.func @multiple_calls_on_same_func(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL0:.*]] = call @bar(%[[ABS0]])
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL0]]
+  // CHECK-NEXT: %[[CALL1:.*]] = call @bar(%[[ABS1]])
+  // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL1]]
+  // CHECK-NEXT: return %[[ABS2]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = call @bar(%0) : (tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  %4 = call @bar(%3) : (tensor<8xf32>) -> (tensor<8xf32>)
+  %5 = sdy.func_data_flow_edge %4 : tensor<8xf32>
+  %6 = stablehlo.abs %5 : tensor<8xf32>
+  return %6 : tensor<8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>)
+func.func private @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1
+  // CHECK-NEXT: return %[[ADD]]
+  %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32>
+  %1 = sdy.func_data_flow_edge %arg1 : tensor<8xf32>
+  %2 = stablehlo.add %0, %1 : tensor<8xf32>
+  return %2 : tensor<8xf32>
+}
+
+// CHECK-LABEL: @simple_call_graph_on_func_with_multiple_argument(%arg0: tensor<8xf32>)
+func.func @simple_call_graph_on_func_with_multiple_argument(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]], %[[ABS1]])
+  // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS2]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = stablehlo.abs %arg0 : tensor<8xf32>
+  %2 = call @bar(%0, %1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>)
+  %3 = sdy.func_data_flow_edge %2 : tensor<8xf32>
+  %4 = stablehlo.abs %3 : tensor<8xf32>
+  return %4 : tensor<8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>)
+func.func private @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1
+  // CHECK-NEXT: return %[[ADD]]
+  %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32>
+  %1 = sdy.func_data_flow_edge %arg1 : tensor<8xf32>
+  %2 = stablehlo.add %0, %1 : tensor<8xf32>
+  return %2 : tensor<8xf32>
+}
+
+// CHECK-LABEL: @simple_call_graph_on_func_with_multiple_argument_same_operand(%arg0: tensor<8xf32>)
+func.func @simple_call_graph_on_func_with_multiple_argument_same_operand(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]], %[[ABS0]])
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS1]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = call @bar(%0, %0) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  return %3 : tensor<8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2]>
+
+// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>})
+func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
+  // CHECK-NEXT: return %[[NEGATE]]
+  %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
+  %1 = stablehlo.negate %0 : tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @simple_call_graph_on_func_with_sharded_argument(%arg0: tensor<8xf32>)
+func.func @simple_call_graph_on_func_with_sharded_argument(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>}
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS1]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = call @bar(%0) : (tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  %4 = stablehlo.abs %2 : tensor<8xf32>
+  return %3 : tensor<8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2]>
+
+// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>})
+func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
+  // CHECK-NEXT: return %[[NEGATE]]
+  %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
+  %1 = stablehlo.negate %0 : tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @func_data_flow_edge_has_sharding_call_does_not(%arg0: tensor<8xf32>)
+func.func @func_data_flow_edge_has_sharding_call_does_not(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}]>]>}
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS1]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  %4 = stablehlo.abs %2 : tensor<8xf32>
+  return %3 : tensor<8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2]>
+
+// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>})
+func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0
+  // CHECK-NEXT: return %[[NEGATE]]
+  %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
+  %1 = stablehlo.negate %0 : tensor<8xf32>
+  return %1 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func private @foo(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>})
+func.func private @foo(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[CALL:.*]] = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}]>]>}
+  // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS]]
+  %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
+  %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @main_calls_foo_calls_bar(%arg0: tensor<8xf32>)
+func.func @main_calls_foo_calls_bar(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0
+  // CHECK-NEXT: %[[CALL:.*]] = call @foo(%[[ABS0]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>}
+  // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]]
+  // CHECK-NEXT: return %[[ABS1]]
+  %0 = stablehlo.abs %arg0 : tensor<8xf32>
+  %1 = call @foo(%0) : (tensor<8xf32>) -> (tensor<8xf32>)
+  %2 = sdy.func_data_flow_edge %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32>
+  %3 = stablehlo.abs %2 : tensor<8xf32>
+  return %3 : tensor<8xf32>
+}
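
The RUN line gains `-split-input-file`, which makes sdy_opt parse each `// -----`-delimited chunk as an independent module; presumably this is what lets the new cases redefine `@mesh` and `@bar` without symbol clashes:

  sdy.mesh @mesh = <["a"=2]>  // first module

  // -----

  sdy.mesh @mesh = <["a"=2]>  // legal again: a fresh module starts after the marker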

shardy/dialect/sdy/transforms/export/test/sink_data_flow_edges_enable_native_non_flat_support.mlir

Lines changed: 0 additions & 102 deletions
This file was deleted.

shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc

Lines changed: 0 additions & 1 deletion
@@ -46,7 +46,6 @@ void populateExportOptions(ExportOptions& options,
   options.avoidReshardsOnCalls = propOptions.avoidReshardsOnNamedComputations;
   options.updateNonDivisibleInputOutputShardings =
       propOptions.updateNonDivisibleInputOutputShardings;
-  options.enableNativeNonFlatSupport = propOptions.enableNativeNonFlatSupport;
 }

 }  // namespace
