Skip to content

Commit b69bd91

Browse files
petebucopybara-github
authored andcommitted
Copy UniquifyFunctionInputsOutputsPass to UniquifyAndMergeReturnsPass
This is a preparatory step to modify the new pass to also merge returns. The original pass and tests are copied and renamed to allow coexistence. PiperOrigin-RevId: 908184894
1 parent 6954b86 commit b69bd91

21 files changed

Lines changed: 559 additions & 92 deletions

shardy/dialect/mpmd/ir/ops.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def FragmentCallOp : Mpmd_Op<"fragment_call",
257257
Variadic<Mpmd_MeshTensorType>:$tensors,
258258
TypedArrayAttrBase<Mpmd_UserOrigin, "array of origin infos">:$origin,
259259
StrAttr:$mesh_name,
260-
FlatSymbolRefAttr:$callee);
260+
FlatSymbolRefAttr:$callee,
261+
OptionalAttr<ArrayAttr>:$inferred_by);
261262
let results = (outs Variadic<Mpmd_MeshTensorType>);
262263

263264
let extraClassDeclaration = [{

shardy/dialect/mpmd/ir/utils.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ SmallVector<MpmdDataflowEdge> GetMpmdDataflowEdges(FuncOp func_op) {
379379

380380
FragmentOp WrapOpWithFragment(
381381
Operation* op, StringRef mesh_name, RewriterBase& rewriter,
382-
std::function<bool(OpOperand&)> should_replace_use) {
382+
StringRef inferred_by, std::function<bool(OpOperand&)> should_replace_use) {
383383
// We set the insertion point right before `op` so assigns of operands will be
384384
// in the right place regardless of previous insertion point.
385385
rewriter.setInsertionPoint(op);
@@ -437,6 +437,10 @@ FragmentOp WrapOpWithFragment(
437437
return block_builder.clone(*op, mapping)->getResults();
438438
});
439439

440+
fragment_op->setAttr(
441+
kInferredByAttr,
442+
rewriter.getArrayAttr({rewriter.getStringAttr(inferred_by)}));
443+
440444
// Unassign all fragment results and replace all uses of `op` with the
441445
// corresponding unassign op for which `should_replace_use` returns true.
442446
for (auto [original_result, fragment_result] :

shardy/dialect/mpmd/ir/utils.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,16 @@ limitations under the License.
4040

4141
namespace mlir::mpmd {
4242

43-
// Globsl sdy mesh name.
44-
constexpr StringRef kGlobalMeshName = "mesh";
45-
43+
// Global sdy mesh name.
44+
inline constexpr StringRef kGlobalMeshName = "mesh";
4645
// The function attribute that holds the SPMD mesh.
47-
constexpr StringRef kMeshShapeAttr = "mesh_shape";
46+
inline constexpr StringRef kMeshShapeAttr = "mesh_shape";
4847
// The function attribute that holds the MPMD topology.
49-
constexpr StringRef kTopologyAttr = "topology";
50-
48+
inline constexpr StringRef kTopologyAttr = "topology";
5149

5250
// The suffix of the mesh name for a CPU mesh.
5351
// LINT.IfChange
54-
constexpr StringRef kCpuMeshSuffix = "/cpu";
52+
inline constexpr StringRef kCpuMeshSuffix = "/cpu";
5553
// LINT.ThenChange(
5654
// https://github.com/openxla/shardy/blob/main/shardy/integrations/python/jax/mpmd/types.py
5755
// )
@@ -76,6 +74,9 @@ inline constexpr StringRef kRematAttributeName = "remat";
7674

7775
inline constexpr StringRef kJaxResultInfoAttr = "jax.result_info";
7876

77+
// The attribute that holds the list of pass names that inferred a fragment.
78+
inline constexpr StringRef kInferredByAttr = "mpmd.inferred_by";
79+
7980
template <typename... Args>
8081
std::string StrCat(Args&&... args) {
8182
std::string result;
@@ -261,6 +262,7 @@ SmallVector<MpmdDataflowEdge> GetMpmdDataflowEdges(func::FuncOp func_op);
261262
// `should_replace_use` returns true.
262263
FragmentOp WrapOpWithFragment(
263264
Operation* op, StringRef mesh_name, RewriterBase& rewriter,
265+
StringRef inferred_by,
264266
std::function<bool(OpOperand&)> should_replace_use = [](OpOperand&) {
265267
return true;
266268
});

shardy/dialect/mpmd/transforms/common/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ cc_library(
4242
"rule_based_merge.cc",
4343
"scheduler_preprocess.cc",
4444
"split_bwd_fragments.cc",
45+
"uniquify_and_merge_returns.cc",
4546
"uniquify_function_inputs_outputs.cc",
4647
"unroll_for_loops.cc",
4748
],

shardy/dialect/mpmd/transforms/common/merge_fragments.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include <vector>
2525

2626
#include "llvm/ADT/STLExtras.h"
27+
#include "llvm/ADT/SetVector.h"
2728
#include "llvm/ADT/StringRef.h"
2829
#include "llvm/Support/Debug.h"
2930
#include "llvm/Support/FormatVariadic.h"
@@ -182,6 +183,31 @@ std::optional<int> MergeCallCounters(FragmentOp producer_op,
182183
return std::nullopt;
183184
}
184185

186+
std::optional<ArrayAttr> MergeInferredByAttributes(FragmentOp producer_op,
187+
FragmentOp consumer_op) {
188+
ArrayAttr producer_inferred_by =
189+
producer_op->getAttrOfType<ArrayAttr>(kInferredByAttr);
190+
ArrayAttr consumer_inferred_by =
191+
consumer_op->getAttrOfType<ArrayAttr>(kInferredByAttr);
192+
193+
if (!producer_inferred_by && !consumer_inferred_by) {
194+
return std::nullopt;
195+
}
196+
197+
llvm::SetVector<Attribute> combined_inferred_by;
198+
if (producer_inferred_by) {
199+
combined_inferred_by.insert(producer_inferred_by.begin(),
200+
producer_inferred_by.end());
201+
}
202+
if (consumer_inferred_by) {
203+
combined_inferred_by.insert(consumer_inferred_by.begin(),
204+
consumer_inferred_by.end());
205+
}
206+
207+
IRRewriter rewriter(producer_op.getContext());
208+
return rewriter.getArrayAttr(combined_inferred_by.takeVector());
209+
}
210+
185211
// Returns a list of attributes that must be preserved in the merged fragment.
186212
// Note: origins are preserved by default and require no extra work.
187213
SmallVector<std::pair<StringRef, Attribute>> MergedAttributes(
@@ -194,6 +220,12 @@ SmallVector<std::pair<StringRef, Attribute>> MergedAttributes(
194220
attributes.emplace_back(kCallCounterAttrName,
195221
rewriter.getUI32IntegerAttr(*merged_call_count));
196222
}
223+
224+
if (std::optional<ArrayAttr> merged_inferred_by =
225+
MergeInferredByAttributes(producer_op, consumer_op)) {
226+
attributes.emplace_back(kInferredByAttr, *merged_inferred_by);
227+
}
228+
197229
return attributes;
198230
}
199231

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,36 @@ def UniquifyFunctionInputsOutputsPass :
461461
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
462462
}
463463

464+
def UniquifyAndMergeReturnsPass :
465+
PassBase<"mpmd-uniquify-and-merge-returns", "DistributedFunctionPass"> {
466+
let summary = "Uniquifies any value returned multiple times or any block "
467+
"argument directly returned by the function.";
468+
let description = [{
469+
If a function returns the same value multiple times, creates multiple
470+
versions for that value, by creating a fragment assigned to that value's
471+
mesh which returns the value multiple times. After this pass, each return
472+
operand is unique. This is important to ensure that the respective results
473+
are allocated in different buffers, as in the following `jax.jit` example:
474+
475+
```python
476+
def f(x):
477+
y = x + x
478+
return y, y
479+
480+
z1, z2 = f(5)
481+
z1 += 1
482+
print(z1) ~~> 6
483+
print(z2) ~~> 5
484+
```
485+
486+
Similarly, if a function returns a block argument, this pass creates an
487+
identity fragment for that block argument, guaranteeing that values are
488+
passed by value to the function, not by reference.
489+
}];
490+
491+
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
492+
}
493+
464494
def SchedulingUnitVerifierPass :
465495
PassBase<"mpmd-scheduling-units-verifier", "DistributedFunctionPass"> {
466496
let summary = "Verifies if the program contains the required scheduling units.";
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// RUN: mpmd_opt %s -mpmd-uniquify-and-merge-returns -split-input-file 2>&1 | FileCheck %s
2+
3+
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
4+
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>>
5+
6+
// CHECK-LABEL: func @no_work_needed
7+
func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_2_tensor) attributes {
8+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
9+
} {
10+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
11+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
12+
// CHECK: return %[[F1]], %[[F2]]
13+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
14+
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
15+
mpmd.return %1 : tensor<4xf32>
16+
} : (!mesh_1_tensor) -> !mesh_1_tensor
17+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
18+
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
19+
mpmd.return %1 : tensor<4xf32>
20+
} : (!mesh_2_tensor) -> !mesh_2_tensor
21+
return %0, %1 : !mesh_1_tensor, !mesh_2_tensor
22+
}
23+
24+
25+
// CHECK-LABEL: func @single_mesh_one_return_operand
26+
func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
27+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
28+
} {
29+
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f1"]>
30+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]]#0)
31+
// CHECK: return %[[F2]], %[[F1]]#1, %[[F1]]#2
32+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
33+
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
34+
mpmd.return %1 : tensor<4xf32>
35+
} : (!mesh_1_tensor) -> !mesh_1_tensor
36+
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0) (%arg1: tensor<4xf32>) {
37+
mpmd.return %arg1 : tensor<4xf32>
38+
} : (!mesh_1_tensor) -> !mesh_1_tensor
39+
return %1, %0, %0 : !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
40+
}
41+
42+
// CHECK-LABEL: func @needs_fragment_for_m1_with_many_values
43+
func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor
44+
) -> (!mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
45+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
46+
} {
47+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
48+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
49+
// CHECK: %[[F3:.*]]:5 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
50+
// CHECK: return %[[F2]], %[[F3]]#0, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3, %[[F3]]#4
51+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
52+
mpmd.return %arg2 : tensor<4xf32>
53+
} : (!mesh_1_tensor) -> !mesh_1_tensor
54+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
55+
mpmd.return %arg2 : tensor<4xf32>
56+
} : (!mesh_2_tensor) -> !mesh_2_tensor
57+
%2 = mpmd.fragment<mesh="m1", origin=["f3"]> (%arg0) (%arg2: tensor<4xf32>) {
58+
mpmd.return %arg2 : tensor<4xf32>
59+
} : (!mesh_1_tensor) -> !mesh_1_tensor
60+
return %1, %0, %2, %0, %2, %2 : !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
61+
}
62+
63+
// CHECK-LABEL: func @needs_fragment_for_m1_and_m2
64+
func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor
65+
) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
66+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
67+
} {
68+
// CHECK: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
69+
// CHECK: %[[F2:.*]]:2 = mpmd.fragment<mesh="m2", origin=["f2"]>
70+
// CHECK: %[[F3:.*]]:4 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
71+
// CHECK: return %[[F3]]#0, %[[F2]]#0, %[[F2]]#1, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3
72+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
73+
mpmd.return %arg2 : tensor<4xf32>
74+
} : (!mesh_1_tensor) -> !mesh_1_tensor
75+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
76+
mpmd.return %arg2 : tensor<4xf32>
77+
} : (!mesh_2_tensor) -> !mesh_2_tensor
78+
%2 = mpmd.fragment<mesh="m1", origin=["f3"]> (%arg0) (%arg2: tensor<4xf32>) {
79+
mpmd.return %arg2 : tensor<4xf32>
80+
} : (!mesh_1_tensor) -> !mesh_1_tensor
81+
return %0, %1, %1, %2, %0, %2 : !mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
82+
}
83+
84+
// -----
85+
86+
!dist_mesh_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>, sharding=<@mesh, [{"x"}]>>
87+
88+
module {
89+
90+
// CHECK-LABEL: func @single_mesh_one_return_operand
91+
func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_tensor) -> (!dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor) attributes {
92+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
93+
} {
94+
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f1"]>
95+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]]#0)
96+
// CHECK: return %[[F2]], %[[F1]]#1, %[[F1]]#2
97+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
98+
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
99+
mpmd.return %1 : tensor<4xf32>
100+
} : (!dist_mesh_tensor) -> !dist_mesh_tensor
101+
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0) (%arg1: tensor<4xf32>) {
102+
mpmd.return %arg1 : tensor<4xf32>
103+
} : (!dist_mesh_tensor) -> !dist_mesh_tensor
104+
return %1, %0, %0 : !dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor
105+
}
106+
}
107+
108+
// -----
109+
110+
!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>
111+
112+
// CHECK-LABEL: func @f
113+
func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
114+
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
115+
{
116+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m", origin=["f"]> (%arg0) (%arg1: tensor<4xui32>) {
117+
// CHECK-NEXT: return %arg1
118+
// CHECK-NEXT: }
119+
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
120+
// CHECK-NEXT: return %arg1, %arg1
121+
// CHECK-NEXT: }
122+
// CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1
123+
%0 = mpmd.fragment<mesh="m", origin=["f"]>(%arg0) (%arg1: tensor<4xui32>) {
124+
mpmd.return %arg1 : tensor<4xui32>
125+
} : (!mesh_tensor) -> !mesh_tensor
126+
func.return %arg0, %0, %arg0 : !mesh_tensor, !mesh_tensor, !mesh_tensor
127+
}
128+
129+
// -----
130+
131+
!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>
132+
133+
// CHECK-LABEL: func @identity_function
134+
func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
135+
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
136+
{
137+
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
138+
// CHECK-NEXT: return %arg1
139+
// CHECK-NEXT: }
140+
// CHECK-NEXT: return %[[F]]
141+
func.return %arg0 : !mesh_tensor
142+
}

shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_ten
2828
} {
2929
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
3030
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
31-
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) (%arg1: tensor<4xf32>) {
31+
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
3232
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
3333
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
3434
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
@@ -49,7 +49,7 @@ func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1:
4949
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
5050
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
5151
// CHECK: %[[F3:.*]] = mpmd.fragment<mesh="m1", origin=["f3"]>
52-
// CHECK: %[[UF:.*]]:5 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]], %[[F3]]) (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>)
52+
// CHECK: %[[UF:.*]]:5 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]], %[[F3]]) {mpmd.inferred_by = ["uniquify"]} (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>)
5353
// CHECK-NEXT: mpmd.return %[[A1]], %[[A1]], %[[A2]], %[[A2]], %[[A2]]
5454
// CHECK-NEXT: }
5555
// CHECK-NEXT: return %[[F2]], %[[UF]]#0, %[[UF]]#2, %[[UF]]#1, %[[UF]]#3, %[[UF]]#4
@@ -70,8 +70,8 @@ func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_te
7070
) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
7171
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
7272
} {
73-
// CHECK: %[[UF1:.*]]:4 = mpmd.fragment<mesh="m1", origin=[]>
74-
// CHECK: %[[UF2:.*]]:2 = mpmd.fragment<mesh="m2", origin=[]>
73+
// CHECK: %[[UF1:.*]]:4 = mpmd.fragment<mesh="m1", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
74+
// CHECK: %[[UF2:.*]]:2 = mpmd.fragment<mesh="m2", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
7575
// CHECK: return %[[UF1]]#0, %[[UF2]]#0, %[[UF2]]#1, %[[UF1]]#2, %[[UF1]]#1, %[[UF1]]#3
7676
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
7777
mpmd.return %arg2 : tensor<4xf32>
@@ -97,7 +97,7 @@ func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_ten
9797
} {
9898
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
9999
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
100-
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) (%arg1: tensor<4xf32>) {
100+
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
101101
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
102102
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
103103
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
@@ -122,7 +122,7 @@ func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
122122
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m", origin=["f"]> (%arg0) (%arg1: tensor<4xui32>) {
123123
// CHECK-NEXT: return %arg1
124124
// CHECK-NEXT: }
125-
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
125+
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
126126
// CHECK-NEXT: return %arg1, %arg1
127127
// CHECK-NEXT: }
128128
// CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1
@@ -140,7 +140,7 @@ func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
140140
func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
141141
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
142142
{
143-
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
143+
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
144144
// CHECK-NEXT: return %arg1
145145
// CHECK-NEXT: }
146146
// CHECK-NEXT: return %[[F]]

0 commit comments

Comments
 (0)