Skip to content

Commit dfb6bde

Browse files
ekayaaslan authored and copybara-github committed
Add pass to propagate shardings from func results to call before shardy inliner.
This change introduces a new pass that sets the result shardings of a `func.call` to match the result shardings of its callee function, but only if the call does not already have result shardings. This way, the ImportFuncCalls pass does not need to handle this case, which simplifies pushing the ImportFuncCalls pass further down the pipeline. PiperOrigin-RevId: 900714449
1 parent 356dda9 commit dfb6bde

File tree

7 files changed

+196
-7
lines changed

7 files changed

+196
-7
lines changed

shardy/dialect/sdy/transforms/import/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cc_library(
4141
"inline_meshes.cc",
4242
"lift_inlined_meshes.cc",
4343
"manual_axes_cleanup.cc",
44+
"propagate_sharding_from_func_to_call.cc",
4445
"remove_size_one_axes.cc",
4546
"sharding_group_import.cc",
4647
],

shardy/dialect/sdy/transforms/import/import_func_calls.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,13 @@ void importCallOp(
6363
SDY_CHECK(funcOp) << "Failed to lookup function: " << calleeName.str();
6464

6565
rewriter.setInsertionPoint(callOp);
66-
TensorShardingPerValueAttr callOpResultShardings =
67-
getShardingPerValue(callOp);
6866
auto namedCompOp = NamedComputationOp::create(
6967
rewriter, callOp->getLoc(), callOp->getResultTypes(),
7068
getOriginalFuncName(funcOp), callOp.getOperands(),
7169
/*inShardings=*/getFuncArgShardings(funcOp, symbolTable),
7270
// TODO(b/439018088): Take func result shardings if call op result
7371
// shardings are empty.
74-
/*outShardings=*/
75-
callOpResultShardings ? callOpResultShardings
76-
: getFuncResultShardings(funcOp, symbolTable));
72+
/*outShardings=*/getShardingPerValue(callOp));
7773
namedCompOp->setAttrs(namedCompAttrs);
7874

7975
Region& namedCompRegion = namedCompOp.getRegion();

shardy/dialect/sdy/transforms/import/import_pipeline.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ void addImportPipeline(OpPassManager& pm, int& dumpIndex,
3030
pm.addPass(createSymbolDCEPass());
3131
pm.addPass(createLiftInlinedMeshesPass());
3232
pm.addPass(createRemoveSizeOneAxesPass());
33+
pm.addPass(createPropagateShardingFromFuncToCallPass());
3334
pm.addPass(createImportFuncCallsPass());
3435
// Keep SymbolDCEPass after ImportFuncCallsPass.
3536
pm.addPass(createSymbolDCEPass());

shardy/dialect/sdy/transforms/import/passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ def ImportFuncCallsPass : Pass<"sdy-import-func-calls", "ModuleOp"> {
2828
let dependentDialects = ["mlir::sdy::SdyDialect"];
2929
}
3030

31+
def PropagateShardingFromFuncToCallPass : Pass<"sdy-propagate-sharding-from-func-to-call", "ModuleOp"> {
  let summary = "Set call result shardings as the func result shardings, if empty.";
  let description = [{
    Propagates the result shardings of each callee function to the `func.call`
    ops that call it, but only for calls that do not already have result
    shardings. Notably, a call keeps its own result shardings if it has any,
    even if all of its individual result shardings are empty.
  }];
  let dependentDialects = ["mlir::sdy::SdyDialect"];
}
41+
3142
def LiftInlinedMeshesPass : Pass<"sdy-lift-inlined-meshes", "ModuleOp"> {
3243
let summary = "Lifts inlined `MeshAttr`s in shardings as symbol `MeshOp`s.";
3344
let description = [{
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/* Copyright 2026 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "shardy/dialect/sdy/transforms/import/passes.h"  // IWYU pragma: keep

namespace mlir {
namespace sdy {

#define GEN_PASS_DEF_PROPAGATESHARDINGFROMFUNCTOCALLPASS
#include "shardy/dialect/sdy/transforms/import/passes.h.inc"

namespace {

using func::CallOp;
using func::FuncOp;

// Copies the result shardings of each callee function onto the `func.call`
// ops that call it, for calls that do not already carry result shardings.
struct PropagateShardingFromFuncToCallPass
    : public impl::PropagateShardingFromFuncToCallPassBase<
          PropagateShardingFromFuncToCallPass> {
  using PropagateShardingFromFuncToCallPassBase::
      PropagateShardingFromFuncToCallPassBase;

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    SymbolTable symbolTable(moduleOp);

    // Propagate shardings from func results to call results if the call does
    // not have them and the func does.
    moduleOp.walk([&](CallOp callOp) {
      // A call that already has result shardings keeps them, even if all of
      // its individual result shardings are empty.
      if (getShardingPerValue(callOp)) {
        return;
      }
      // Only look up the callee once we know its shardings may be needed;
      // this avoids an unnecessary symbol-table lookup (and a hard failure on
      // an unresolvable callee) for calls that are left untouched.
      FuncOp funcOp = getFuncOpOrDie(callOp.getCallee(), symbolTable);
      if (TensorShardingPerValueAttr funcResultShardings =
              getFuncResultShardings(funcOp, symbolTable)) {
        setShardings(callOp, funcResultShardings);
      }
    });
  }
};

}  // namespace

}  // namespace sdy
}  // namespace mlir

shardy/dialect/sdy/transforms/import/test/import_func_calls.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>
433433

434434
// CHECK-LABEL: func @func_has_out_sharding_call_no_out_sharding
435435
func.func @func_has_out_sharding_call_no_out_sharding(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> tensor<8x2xi32> {
436-
// CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%arg0, %arg0) out_shardings=[<@mesh, [{"x", "y"}, {}]>] (%arg1: tensor<8x2xi32>, %arg2: tensor<8x2xi32>) {
436+
// CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%arg0, %arg0) (%arg1: tensor<8x2xi32>, %arg2: tensor<8x2xi32>) {
437437
// CHECK-NEXT: %[[MULTIPLY:.*]] = stablehlo.multiply %arg1, %arg2 : tensor<8x2xi32>
438438
// CHECK-NEXT: sdy.return %[[MULTIPLY]] : tensor<8x2xi32>
439439
// CHECK-NEXT: } {mhlo.frontend_attributes = {inlineable = "false"}} : (tensor<8x2xi32>, tensor<8x2xi32>) -> tensor<8x2xi32>
@@ -529,7 +529,7 @@ sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>
529529

530530
// CHECK-LABEL: func @func_has_out_sharding_on_one_result_call_has_no_out_sharding
531531
func.func @func_has_out_sharding_on_one_result_call_has_no_out_sharding(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> tensor<8x2xi32> {
532-
// CHECK-NEXT: %[[NC:.*]]:2 = sdy.named_computation<"foo">(%arg0, %arg0) out_shardings=[<@mesh, [{?}, {?}]>, <@mesh, [{"y"}, {}]>] (%arg1: tensor<8x2xi32>, %arg2: tensor<8x2xi32>) {
532+
// CHECK-NEXT: %[[NC:.*]]:2 = sdy.named_computation<"foo">(%arg0, %arg0) (%arg1: tensor<8x2xi32>, %arg2: tensor<8x2xi32>) {
533533
// CHECK-NEXT: %[[MULTIPLY:.*]] = stablehlo.multiply %arg1, %arg2 : tensor<8x2xi32>
534534
// CHECK-NEXT: %[[TRANSPOSE:.*]] = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<8x2xi32>) -> tensor<2x8xi32>
535535
// CHECK-NEXT: %[[DOT_1:.*]] = stablehlo.dot %[[TRANSPOSE]], %arg1 : (tensor<2x8xi32>, tensor<8x2xi32>) -> tensor<2x2xi32>
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// RUN: sdy_opt %s -sdy-propagate-sharding-from-func-to-call -split-input-file | FileCheck %s

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @propagate_func_to_call
func.func @propagate_func_to_call(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
  // CHECK-NEXT: %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
  %0 = call @foo(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32>
  return %0 : tensor<8x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
  return %arg0 : tensor<8x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @do_not_overwrite_call_sharding
func.func @do_not_overwrite_call_sharding(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
  // CHECK-NEXT: %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
  %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
  return %0 : tensor<8x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
  return %arg0 : tensor<8x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @both_call_and_func_has_empty_result_shardings
func.func @both_call_and_func_has_empty_result_shardings(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
  // CHECK-NEXT: %0 = call @foo(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32>
  %0 = call @foo(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32>
  return %0 : tensor<8x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
  return %arg0 : tensor<8x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @multiple_results
func.func @multiple_results(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
  // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  %0:2 = call @foo(%arg0, %arg1) : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) {
  return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @keep_empty_call_sharding
func.func @keep_empty_call_sharding(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
  // CHECK-NEXT: %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
  %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
  return %0 : tensor<8x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
  return %arg0 : tensor<8x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @multiple_results_one_same_one_is_empty
func.func @multiple_results_one_same_one_is_empty(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
  // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) {
  return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @multiple_results_one_different_one_is_empty
func.func @multiple_results_one_different_one_is_empty(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
  // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) {
  return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
}

// -----

sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]>

// CHECK-LABEL: func @multiple_results_call_no_sharding_func_has_sharding_on_one_no_sharding_on_the_other
func.func @multiple_results_call_no_sharding_func_has_sharding_on_one_no_sharding_on_the_other(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
  // CHECK-NEXT: %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{?}, {?}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  %0:2 = call @foo(%arg0, %arg1) : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, tensor<4x2xi32>) {
  return %arg0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
}

0 commit comments

Comments
 (0)