Skip to content

Commit 1e3e36d

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy outliner up out of export pipeline.
PiperOrigin-RevId: 900660386
1 parent 8e6c383 commit 1e3e36d

3 files changed

Lines changed: 8 additions & 10 deletions

File tree

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
7070

7171
void addExportPipeline(OpPassManager& pm, int& dumpIndex,
7272
const ExportOptions& options) {
73-
pm.addPass(createExportNamedComputationsPass());
7473
pm.addNestedPass<func::FuncOp>(createConstantOrScalarMergerPass());
7574
if (!options.avoidExportForPartitioning) {
7675
pm.addPass(createRemoveShardingGroupsPass());

shardy/dialect/sdy/transforms/export/test/export_pipeline.mlir

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func.func @main(
6464
// -----
6565
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>
6666

67-
// Test: named_computation_with_shardings
67+
// Test: call_with_shardings
6868
// CHECK-LABEL: func @main
6969
func.func @main(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> tensor<12x2xi32> {
7070
// CHECK-NEXT: %0 = sdy.reshard %arg0 <@mesh, [{"a"}, {}]> : tensor<8x2xi32>
@@ -73,15 +73,13 @@ func.func @main(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> tensor<12x2xi
7373
// CHECK-NEXT: %3 = sdy.reshard %1#1 <@mesh, [{}, {"a"}]> : tensor<4x2xi32>
7474
// CHECK-NEXT: %4 = stablehlo.concatenate %2, %3, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>}
7575
// CHECK-NEXT: return %4 : tensor<12x2xi32>
76-
%0:2 = sdy.named_computation<"foo">(%arg0, %arg1) in_shardings=[<@mesh, [{"a"}, {}]>, <@mesh, [{?}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>, <@mesh, [{?}, {}]>] (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
77-
sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
78-
} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
76+
%0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
7977
%1 = stablehlo.concatenate %0#0, %0#1, dim=0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> tensor<12x2xi32>
8078
return %1 : tensor<12x2xi32>
8179
}
8280

83-
// CHECK-LABEL: func private @foo(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
84-
// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
85-
// CHECK-SAME: attributes {sdy.original_func_name = "foo"} {
86-
// CHECK-NEXT: return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
87-
// CHECK-NEXT: }
81+
func.func private @foo(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
82+
-> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
83+
attributes {sdy.original_func_name = "foo"} {
84+
return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
85+
}

shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ void addPropagationPipeline(OpPassManager& pm, int& dumpIndex,
7474
dumpIndex++));
7575
AutoPartitionerRegistry::addPasses(pm);
7676
}
77+
pm.addPass(createExportNamedComputationsPass());
7778
ExportOptions exportOptions;
7879
populateExportOptions(exportOptions, options);
7980
addExportPipeline(pm, dumpIndex, exportOptions);

0 commit comments

Comments
 (0)