
Commit 3f74a70

ekayaaslan authored and copybara-github committed
Push shardy outliner up past UpdateNonDivisibleInputOutputShardings pass.
The UpdateNonDivisibleInputOutputShardings pass updates non-divisible input/output shardings for all func inputs/outputs and for all shardable data-flow ops, including the in/out shardings of named computations. It should similarly update them for call ops.

PiperOrigin-RevId: 900135043
1 parent 56477a9 commit 3f74a70

20 files changed: +575 −211 lines

shardy/dialect/sdy/ir/constants.h

Lines changed: 3 additions & 0 deletions

@@ -86,6 +86,9 @@ inline const std::string kEmptyMeshSymbol = "empty_mesh";
 // Attribute name for the original name of the func before flattening.
 inline constexpr llvm::StringRef kOriginalFuncName = "sdy.original_func_name";
 
+// Name of the main func.
+inline constexpr llvm::StringRef kMainFuncName = "main";
+
 }  // namespace sdy
 }  // namespace mlir

shardy/dialect/sdy/ir/utils.cc

Lines changed: 20 additions & 0 deletions

@@ -1108,5 +1108,25 @@ FuncOp cloneFuncRecursively(FuncOp funcOp, SymbolTable& symbolTable) {
   return clonedFuncOp;
 }
 
+TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
+                                              Attribute meshOrRef) {
+  SmallVector<TensorShardingAttr> resultShardings;
+  resultShardings.reserve(values.size());
+  for (mlir::Value value : values) {
+    resultShardings.push_back(TensorShardingAttr::getFullyReplicated(
+        meshOrRef.getContext(), mlir::sdy::getTensorRank(value), meshOrRef,
+        /*isClosed=*/true));
+  }
+  return TensorShardingPerValueAttr::get(meshOrRef.getContext(),
+                                         resultShardings);
+}
+
+// Returns the main func. Dies if there is no main func.
+FuncOp getMainFuncOrDie(ModuleOp moduleOp, SymbolTable& symbolTable) {
+  FuncOp funcOp = symbolTable.lookup<FuncOp>(kMainFuncName);
+  SDY_CHECK(funcOp) << "Failed to lookup function: " << kMainFuncName.str();
+  return funcOp;
+}
+
 }  // namespace sdy
 }  // namespace mlir

shardy/dialect/sdy/ir/utils.h

Lines changed: 10 additions & 0 deletions

@@ -678,6 +678,16 @@ Operation* getCommonSupportedReductionOp(stablehlo::ScatterOp scatter);
 mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
                                         SymbolTable& symbolTable);
 
+// Returns a `TensorShardingPerValueAttr` on the shardings of the `values`. If
+// the sharding of a value is null, it creates a fully closed sharding for it on
+// the given `meshOrRef` and the rank of the tensor corresponding to the value.
+TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
+                                              Attribute meshOrRef);
+
+// Returns the main func. Dies if there is no main func.
+mlir::func::FuncOp getMainFuncOrDie(ModuleOp moduleOp,
+                                    SymbolTable& symbolTable);
+
 }  // namespace sdy
 }  // namespace mlir

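For orientation, a minimal caller sketch (not part of the commit) showing how the two new helpers compose. The function name `shardingsForCallOperands` is hypothetical, and the includes are assumed from the surrounding SDY sources:

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/IR/SymbolTable.h"
  #include "shardy/dialect/sdy/ir/utils.h"

  namespace mlir::sdy {

  // Hypothetical helper: after checking that the module has a @main func,
  // build one fully replicated, fully closed sharding per call operand.
  TensorShardingPerValueAttr shardingsForCallOperands(ModuleOp moduleOp,
                                                      func::CallOp callOp,
                                                      Attribute meshOrRef) {
    SymbolTable symbolTable(moduleOp);
    // Dies (SDY_CHECK) if the module has no @main func.
    func::FuncOp mainFunc = getMainFuncOrDie(moduleOp, symbolTable);
    (void)mainFunc;
    // One fully closed sharding per operand, on the given mesh (or mesh ref).
    return getFullyClosedLike(callOp.getOperands(), meshOrRef);
  }

  }  // namespace mlir::sdy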
shardy/dialect/sdy/transforms/common/propagation_options.h

Lines changed: 2 additions & 1 deletion

@@ -48,7 +48,8 @@ struct PropagationOptions {
   // auto-partitioner will be invoked after propagation of user-specified
   // shardings.
   bool enableAutoPartitioning = false;
-  // Whether to avoid explicit reshards/collectives on named computations.
+  // Whether to avoid explicit reshards/collectives on named computations/calls.
+  // TODO(enver): Rename to avoidReshardsOnCalls.
   bool avoidReshardsOnNamedComputations = false;
   // Whether to update axes with non-divisible input/output shardings.
   bool updateNonDivisibleInputOutputShardings = true;

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 2 additions & 5 deletions

@@ -40,10 +40,8 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
                           const ExportOptions& options) {
   InsertExplicitReshardsPassOptions passOptions;
   passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
-  passOptions.avoidReshardsOnNamedComputations =
-      options.avoidReshardsOnNamedComputations;
+  passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
   pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
-  pm.addPass(createExportNamedComputationsPass());
   if (options.enableInsertExplicitCollectives) {
     pm.addPass(mlir::sdy::createSaveModuleOpPass(
         options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
@@ -82,6 +80,7 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
       /*sinkDebugShardingOrigins=*/options.dumpShardingOrigins,
       /*sinkDebugPropagationEdgeSharding=*/options.dumpPropagationEdges,
       /*sinkEnableNativeNonFlatSupport=*/options.enableNativeNonFlatSupport}));
+  pm.addPass(createExportNamedComputationsPass());
   if (options.updateNonDivisibleInputOutputShardings) {
     pm.addPass(createUpdateNonDivisibleInputOutputShardingsPass());
     pm.addPass(createRemoveSubAxesInInputOutputShardingsPass());
@@ -98,8 +97,6 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
   // reshards/collectives.
   if (!options.avoidExportForPartitioning) {
     runShardyPartitioner(pm, dumpIndex, options);
-  } else {
-    pm.addPass(createExportNamedComputationsPass());
   }
   if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
     pm.addPass(createRemovePropagationDebugInfoPass());

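Net effect on pass ordering: `createExportNamedComputationsPass` now runs in `addExportPipeline` before the non-divisible input/output sharding update, instead of inside `runShardyPartitioner` (or in its else branch). A condensed sketch of the resulting order; it assumes only the pass-creation functions named in the diffs, with the options plumbing elided:

  void sketchExportOrder(mlir::OpPassManager& pm) {
    // 1. Outline named computations into func.call ops first...
    pm.addPass(mlir::sdy::createExportNamedComputationsPass());
    // 2. ...so the update of non-divisible input/output shardings also sees
    //    (and fixes up) the shardings on the resulting call ops.
    pm.addPass(mlir::sdy::createUpdateNonDivisibleInputOutputShardingsPass());
    pm.addPass(mlir::sdy::createRemoveSubAxesInInputOutputShardingsPass());
    // 3. Explicit reshard insertion then runs per-func in the partitioner,
    //    now keyed on the renamed avoid-reshards-on-calls option.
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::sdy::createInsertExplicitReshardsPass());
  }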
shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 44 additions & 19 deletions

@@ -51,6 +51,9 @@ namespace sdy {
 
 namespace {
 
+using func::CallOp;
+using func::FuncOp;
+
 void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
                                             TensorShardingAttr targetSharding,
                                             IRRewriter& rewriter,
@@ -102,22 +105,10 @@ void insertExplicitReshardsOnFuncReturn(Operation* op, func::FuncOp& funcOp,
   }
 }
 
-void insertExplicitReshardsOnDataFlowOp(
-    ShardableDataFlowOpInterface& op, IRRewriter& rewriter,
-    const SymbolTable& symbolTable, const bool onFullVersion,
-    const bool avoidReshardsOnNamedComputations) {
-  if (isa<NamedComputationOp>(op) && avoidReshardsOnNamedComputations) {
-    for (Value owner : op.getOpResultEdgeOwners()) {
-      for (OpOperand* sourceOpOperand : op.getEdgeSources(owner)) {
-        insertExplicitReshardsToTargetSharding(
-            *sourceOpOperand,
-            /*targetSharding=*/op.getEdgeOwnerSharding(owner), rewriter,
-            symbolTable,
-            /*insertAfterOperand=*/true, onFullVersion);
-      }
-    }
-    return;
-  }
+void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
+                                        IRRewriter& rewriter,
+                                        const SymbolTable& symbolTable,
+                                        const bool onFullVersion) {
   for (Value owner : llvm::concat<Value>(op.getOpResultEdgeOwners(),
                                          op.getBlockArgumentEdgeOwners())) {
     TensorShardingAttr ownerSharding = op.transformTargetSharding(
@@ -132,6 +123,33 @@ void insertExplicitReshardsOnDataFlowOp(
   }
 }
 
+void insertExplicitReshardsOnCallOp(CallOp callOp, IRRewriter& rewriter,
+                                    const SymbolTable& symbolTable,
+                                    const bool onFullVersion) {
+  FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
+  TensorShardingPerValueAttr funcArgShardings =
+      mlir::sdy::getFuncArgShardings(funcOp, symbolTable);
+  if (!funcArgShardings) {
+    mlir::Attribute meshOrRef = getMeshOrRef(
+        callOp.getNumOperands(), symbolTable,
+        [&](int64_t i) { return getSharding(callOp.getOperand(i)); });
+    // Return without inserting reshards as neither func arguments nor call
+    // operands have a sharding with non-maximal mesh.
+    if (!meshOrRef) {
+      return;
+    }
+    funcArgShardings = getFullyClosedLike(callOp.getOperands(), meshOrRef);
+  }
+  rewriter.setInsertionPoint(callOp);
+  for (auto [funcArgSharding, sourceOpOperand] : llvm::zip_equal(
+           funcArgShardings.getShardings(), callOp->getOpOperands())) {
+    insertExplicitReshardsToTargetSharding(
+        sourceOpOperand,
+        /*targetSharding=*/funcArgSharding, rewriter, symbolTable,
+        /*insertAfterOperand=*/true, onFullVersion);
+  }
+}
+
 // Reshard the result of a dot operation if all the following hold:
 //
 // 1. LHS and RHS have fully compatible shardings.
@@ -382,7 +400,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
   }
   // To avoid copies of the same functions with mismatching shardings on the
   // arguments onto multiple callsites.
-  if (isa<NamedComputationOp>(op)) {
+  if (isa<func::CallOp>(op)) {
    return true;
   }
 
@@ -472,8 +490,15 @@ struct InsertExplicitReshardsPass
       // TODO(enver): Prefer resharding the owner when multiple sources are
       // sharded in the same way.
       insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
-                                         symbolTable, onFullVersion,
-                                         avoidReshardsOnNamedComputations);
+                                         symbolTable, onFullVersion);
+      return;
+    }
+
+    if (CallOp callOp = dyn_cast<CallOp>(op)) {
+      if (!avoidReshardsOnCalls) {
+        insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
+                                       onFullVersion);
+      }
       return;
     }

shardy/dialect/sdy/transforms/export/passes.h

Lines changed: 4 additions & 4 deletions

@@ -76,10 +76,10 @@ struct ExportOptions : public PassPipelineOptions<ExportOptions> {
       llvm::cl::desc("Sink sdy.propagation_edges attr."),
       llvm::cl::init(false)};
 
-  Option<bool> avoidReshardsOnNamedComputations{
-      *this, "avoid-reshards-on-named-computations",
-      llvm::cl::desc("Avoid inserting explicit reshards/collectives for named "
-                     "computations."),
+  Option<bool> avoidReshardsOnCalls{
+      *this, "avoid-reshards-on-calls",
+      llvm::cl::desc(
+          "Avoid inserting explicit reshards/collectives for calls."),
       llvm::cl::init(false)};
 
   Option<bool> updateNonDivisibleInputOutputShardings{

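A short driver sketch for the renamed flag (hypothetical setup; it assumes `addExportPipeline` consumes these `ExportOptions`, which is consistent with the export_pipeline.cc diff above but not shown in full):

  mlir::sdy::ExportOptions options;
  // Formerly avoidReshardsOnNamedComputations: when set, the pipeline skips
  // reshard/collective insertion on func.call ops.
  options.avoidReshardsOnCalls = true;
  int dumpIndex = 0;
  mlir::OpPassManager pm("builtin.module");
  mlir::sdy::addExportPipeline(pm, dumpIndex, options);

On the command line the same option is spelled `avoid-reshards-on-calls`, as in the RUN line of the new test below.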
shardy/dialect/sdy/transforms/export/passes.td

Lines changed: 3 additions & 3 deletions

@@ -128,10 +128,10 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> {
     Option<"enableFullVersion", "enable-full-version",
            "bool", /*default=*/"false",
            "Enable full version.">,
-    Option<"avoidReshardsOnNamedComputations",
-           "avoid-reshards-on-named-computations",
+    Option<"avoidReshardsOnCalls",
+           "avoid-reshards-on-calls",
            "bool", /*default=*/"false",
-           "Avoid explicit reshards/collectives on named computations.">
+           "Avoid explicit reshards/collectives on calls.">
   ];
 }

shardy/dialect/sdy/transforms/export/remove_sub_axes_in_input_output_shardings.cc

Lines changed: 18 additions & 16 deletions

@@ -23,6 +23,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"  // IWYU pragma: keep
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -82,22 +83,23 @@ struct RemoveSubAxesInInputOutputShardingsPass
       RemoveSubAxesInInputOutputShardingsPassBase;
 
   void runOnOperation() final {
-    for (auto funcOp : getOperation().getOps<func::FuncOp>()) {
-      // Update arguments.
-      updateValueShardings(
-          funcOp.getNumArguments(),
-          [&](int64_t index) { return getSharding(funcOp.getArgument(index)); },
-          [&](int64_t index, TensorShardingAttr sharding) {
-            setSharding(funcOp.getArgument(index), sharding);
-          });
-      // Update results.
-      updateValueShardings(
-          funcOp.getNumResults(),
-          [&](int64_t index) { return getFuncResultSharding(funcOp, index); },
-          [&](int64_t index, TensorShardingAttr sharding) {
-            setFuncResultSharding(funcOp, index, sharding);
-          });
-    }
+    ModuleOp moduleOp = getOperation();
+    SymbolTable symbolTable(moduleOp);
+    func::FuncOp funcOp = getMainFuncOrDie(moduleOp, symbolTable);
+    // Update arguments.
+    updateValueShardings(
+        funcOp.getNumArguments(),
+        [&](int64_t index) { return getSharding(funcOp.getArgument(index)); },
+        [&](int64_t index, TensorShardingAttr sharding) {
+          setSharding(funcOp.getArgument(index), sharding);
+        });
+    // Update results.
+    updateValueShardings(
+        funcOp.getNumResults(),
+        [&](int64_t index) { return getFuncResultSharding(funcOp, index); },
+        [&](int64_t index, TensorShardingAttr sharding) {
+          setFuncResultSharding(funcOp, index, sharding);
+        });
   }
 };

Lines changed: 126 additions & 0 deletions (new test file)

// RUN: sdy_opt %s -split-input-file -sdy-insert-explicit-reshards='enable-full-version=true avoid-reshards-on-calls=true' | FileCheck %s

sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>

// CHECK-LABEL: func @call
func.func @call(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
  // CHECK-NEXT: %[[CALL:.*]] = call @foo(%arg0)
  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %[[CALL]]
  // CHECK-NEXT: return %[[NEGATE]]
  %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : (tensor<210xf32>) -> (tensor<210xf32>)
  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
  return %1 : tensor<210xf32>
}

// CHECK-LABEL: func private @foo
func.func private @foo(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
  // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ABS]] <@mesh, [{"z"}]>
  // CHECK-NEXT: return %[[RESHARD]]
  %0 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
  return %0 : tensor<210xf32>
}

// -----
sdy.mesh @mesh = <["x"=4, "y"=2]>

// CHECK-LABEL: func @call_empty_block
func.func @call_empty_block(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
  // CHECK-NEXT: %[[CALL:.*]] = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<210xf32>) -> (tensor<210xf32>)
  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
  return %1 : tensor<210xf32>
}

// CHECK-LABEL: func private @foo
func.func private @foo(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]>
  // CHECK-NEXT: return %[[RESHARD]]
  return %arg0 : tensor<210xf32>
}

// -----
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>

// CHECK-LABEL: func @call_with_shardings
func.func @call_with_shardings(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> tensor<12x2xi32> {
  // CHECK-NEXT: %[[CALL:.*]]:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>, <@mesh, [{}, {}]>]>}
  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %[[CALL]]#0 <@mesh, [{}, {"a"}]>
  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[CALL]]#1 <@mesh, [{}, {"a"}]>
  // CHECK-NEXT: %[[CONCAT:.*]] = stablehlo.concatenate %[[RESHARD1]], %[[RESHARD2]], dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>}
  // CHECK-NEXT: %[[RESHARD3:.*]] = sdy.reshard %[[CONCAT]] <@mesh, [{}, {}]>
  %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
  %1 = stablehlo.concatenate %0#0, %0#1, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> tensor<12x2xi32>
  return %1 : tensor<12x2xi32>
}

// CHECK-LABEL: func private @foo
func.func private @foo(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
    -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) {
  // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>}
  // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %[[ABS]] <@mesh, [{}, {"a"}]>
  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %[[RESHARD0]] <@mesh, [{"a"}, {}]>
  // CHECK-NEXT: return %[[RESHARD1]], %arg1
  %0 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : tensor<8x2xi32>
  return %0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
}

// -----
sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>

// CHECK-LABEL: func @one_argument_to_multiple_calls(
func.func @one_argument_to_multiple_calls(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
  // CHECK-NEXT: %[[CALL0:.*]] = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  // CHECK-NEXT: %[[CALL1:.*]] = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>}
  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[CALL0]] <@mesh, [{"z"}]>
  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[RESHARD]], %[[CALL1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>}
  // CHECK-NEXT: return %[[ADD]]
  %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<210xf32>) -> (tensor<210xf32>)
  %1 = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : (tensor<210xf32>) -> (tensor<210xf32>)
  %3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
  return %3 : tensor<210xf32>
}

// CHECK-LABEL: func private @foo
func.func private @foo(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
  // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0
  // CHECK-NEXT: return %[[ABS]]
  %0 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
  return %0 : tensor<210xf32>
}

// CHECK-LABEL: func private @bar
func.func private @bar(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
  // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0
  // CHECK-NEXT: return %[[ABS]]
  %0 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
  return %0 : tensor<210xf32>
}

// -----
sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>

// CHECK-LABEL: func @different_arguments_to_multiple_calls_with_same_input_output_shardings
func.func @different_arguments_to_multiple_calls_with_same_input_output_shardings(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
  // CHECK-NEXT: %[[CALL0:.*]] = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>}
  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[NEGATE]] <@mesh, [{"y"}]>
  // CHECK-NEXT: %[[CALL1:.*]] = call @foo(%[[RESHARD]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[CALL0]], %[[CALL1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
  // CHECK-NEXT: return %[[ADD]]
  %0 = call @foo(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<210xf32>) -> (tensor<210xf32>)
  %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
  %2 = call @foo(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<210xf32>) -> (tensor<210xf32>)
  %4 = stablehlo.add %0, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
  return %4 : tensor<210xf32>
}

// CHECK-LABEL: func private @foo
func.func private @foo(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
  // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %arg0
  // CHECK-NEXT: return %[[ABS]]
  %3 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
  return %3 : tensor<210xf32>
}
