Skip to content

Commit ca590e2

Browse files
petebucopybara-github
authored andcommitted
Copy UniquifyFunctionInputsOutputsPass to UniquifyAndMergeReturnsPass
This is a preparatory step to modify the new pass to also merge returns. The original pass and tests are copied and renamed to allow coexistence. PiperOrigin-RevId: 908184894
1 parent 9f4effd commit ca590e2

5 files changed

Lines changed: 340 additions & 6 deletions

File tree

shardy/dialect/mpmd/transforms/common/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ cc_library(
4242
"rule_based_merge.cc",
4343
"scheduler_preprocess.cc",
4444
"split_bwd_fragments.cc",
45+
"uniquify_and_merge_returns.cc",
4546
"uniquify_function_inputs_outputs.cc",
4647
"unroll_for_loops.cc",
4748
],

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,36 @@ def UniquifyFunctionInputsOutputsPass :
461461
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
462462
}
463463

464+
def UniquifyAndMergeReturnsPass :
465+
PassBase<"mpmd-uniquify-and-merge-returns", "DistributedFunctionPass"> {
466+
let summary = "Uniquifies any value returned multiple times or any block "
467+
"argument directly returned by the function.";
468+
let description = [{
469+
If a function returns the same value multiple times, creates multiple
470+
versions for that value, by creating a fragment assigned to that value's
471+
mesh which returns the value multiple times. After this pass, each return
472+
operand is unique. This is important to ensure that the respective results
473+
are allocated in different buffers, as in the following `jax.jit` example:
474+
475+
```python
476+
def f(x):
477+
y = x + x
478+
return y, y
479+
480+
z1, z2 = f(5)
481+
z1 += 1
482+
print(z1) ~~> 6
483+
print(z2) ~~> 5
484+
```
485+
486+
Similarly, if a function returns a block argument, this pass creates an
487+
identity fragment for that block argument, guaranteeing that values are
488+
passed by value to the function, not by reference.
489+
}];
490+
491+
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
492+
}
493+
464494
def SchedulingUnitVerifierPass :
465495
PassBase<"mpmd-scheduling-units-verifier", "DistributedFunctionPass"> {
466496
let summary = "Verifies if the program contains the required scheduling units.";
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// RUN: mpmd_opt %s -mpmd-uniquify-and-merge-returns -split-input-file 2>&1 | FileCheck %s
2+
3+
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
4+
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>>
5+
6+
// CHECK-LABEL: func @no_work_needed
7+
func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_2_tensor) attributes {
8+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
9+
} {
10+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
11+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
12+
// CHECK: return %[[F1]], %[[F2]]
13+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
14+
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
15+
mpmd.return %1 : tensor<4xf32>
16+
} : (!mesh_1_tensor) -> !mesh_1_tensor
17+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
18+
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
19+
mpmd.return %1 : tensor<4xf32>
20+
} : (!mesh_2_tensor) -> !mesh_2_tensor
21+
return %0, %1 : !mesh_1_tensor, !mesh_2_tensor
22+
}
23+
24+
25+
// CHECK-LABEL: func @single_mesh_one_return_operand
26+
func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
27+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
28+
} {
29+
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f1"]>
30+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]]#0)
31+
// CHECK: return %[[F2]], %[[F1]]#1, %[[F1]]#2
32+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
33+
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
34+
mpmd.return %1 : tensor<4xf32>
35+
} : (!mesh_1_tensor) -> !mesh_1_tensor
36+
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0) (%arg1: tensor<4xf32>) {
37+
mpmd.return %arg1 : tensor<4xf32>
38+
} : (!mesh_1_tensor) -> !mesh_1_tensor
39+
return %1, %0, %0 : !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
40+
}
41+
42+
// CHECK-LABEL: func @needs_fragment_for_m1_with_many_values
43+
func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor
44+
) -> (!mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
45+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
46+
} {
47+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
48+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
49+
// CHECK: %[[F3:.*]]:5 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
50+
// CHECK: return %[[F2]], %[[F3]]#0, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3, %[[F3]]#4
51+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
52+
mpmd.return %arg2 : tensor<4xf32>
53+
} : (!mesh_1_tensor) -> !mesh_1_tensor
54+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
55+
mpmd.return %arg2 : tensor<4xf32>
56+
} : (!mesh_2_tensor) -> !mesh_2_tensor
57+
%2 = mpmd.fragment<mesh="m1", origin=["f3"]> (%arg0) (%arg2: tensor<4xf32>) {
58+
mpmd.return %arg2 : tensor<4xf32>
59+
} : (!mesh_1_tensor) -> !mesh_1_tensor
60+
return %1, %0, %2, %0, %2, %2 : !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
61+
}
62+
63+
// CHECK-LABEL: func @needs_fragment_for_m1_and_m2
64+
func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor
65+
) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
66+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
67+
} {
68+
// CHECK: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
69+
// CHECK: %[[F2:.*]]:2 = mpmd.fragment<mesh="m2", origin=["f2"]>
70+
// CHECK: %[[F3:.*]]:4 = mpmd.fragment<mesh="m1", origin=["f3"]> (%[[F1]], %arg0)
71+
// CHECK: return %[[F3]]#0, %[[F2]]#0, %[[F2]]#1, %[[F3]]#2, %[[F3]]#1, %[[F3]]#3
72+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
73+
mpmd.return %arg2 : tensor<4xf32>
74+
} : (!mesh_1_tensor) -> !mesh_1_tensor
75+
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
76+
mpmd.return %arg2 : tensor<4xf32>
77+
} : (!mesh_2_tensor) -> !mesh_2_tensor
78+
%2 = mpmd.fragment<mesh="m1", origin=["f3"]> (%arg0) (%arg2: tensor<4xf32>) {
79+
mpmd.return %arg2 : tensor<4xf32>
80+
} : (!mesh_1_tensor) -> !mesh_1_tensor
81+
return %0, %1, %1, %2, %0, %2 : !mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor
82+
}
83+
84+
// -----
85+
86+
!dist_mesh_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>, sharding=<@mesh, [{"x"}]>>
87+
88+
module {
89+
90+
// CHECK-LABEL: func @single_mesh_one_return_operand
91+
func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_tensor) -> (!dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor) attributes {
92+
"topology"=#mpmd.topology<<"m1": <["x"=2]>>>
93+
} {
94+
// CHECK-NEXT: %[[F1:.*]]:3 = mpmd.fragment<mesh="m1", origin=["f1"]>
95+
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]> (%[[F1]]#0)
96+
// CHECK: return %[[F2]], %[[F1]]#1, %[[F1]]#2
97+
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
98+
%1 = stablehlo.add %arg1, %arg1 : tensor<4xf32>
99+
mpmd.return %1 : tensor<4xf32>
100+
} : (!dist_mesh_tensor) -> !dist_mesh_tensor
101+
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0) (%arg1: tensor<4xf32>) {
102+
mpmd.return %arg1 : tensor<4xf32>
103+
} : (!dist_mesh_tensor) -> !dist_mesh_tensor
104+
return %1, %0, %0 : !dist_mesh_tensor, !dist_mesh_tensor, !dist_mesh_tensor
105+
}
106+
}
107+
108+
// -----
109+
110+
!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>
111+
112+
// CHECK-LABEL: func @f
113+
func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
114+
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
115+
{
116+
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m", origin=["f"]> (%arg0) (%arg1: tensor<4xui32>) {
117+
// CHECK-NEXT: return %arg1
118+
// CHECK-NEXT: }
119+
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
120+
// CHECK-NEXT: return %arg1, %arg1
121+
// CHECK-NEXT: }
122+
// CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1
123+
%0 = mpmd.fragment<mesh="m", origin=["f"]>(%arg0) (%arg1: tensor<4xui32>) {
124+
mpmd.return %arg1 : tensor<4xui32>
125+
} : (!mesh_tensor) -> !mesh_tensor
126+
func.return %arg0, %0, %arg0 : !mesh_tensor, !mesh_tensor, !mesh_tensor
127+
}
128+
129+
// -----
130+
131+
!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>
132+
133+
// CHECK-LABEL: func @identity_function
134+
func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
135+
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
136+
{
137+
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
138+
// CHECK-NEXT: return %arg1
139+
// CHECK-NEXT: }
140+
// CHECK-NEXT: return %[[F]]
141+
func.return %arg0 : !mesh_tensor
142+
}
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/* Copyright 2025 The MPMD Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cstdint>
17+
#include <utility>
18+
19+
#include "llvm/ADT/MapVector.h"
20+
#include "llvm/ADT/SmallVector.h"
21+
#include "mlir/Dialect/Func/IR/FuncOps.h"
22+
#include "mlir/IR/Builders.h"
23+
#include "mlir/IR/BuiltinAttributes.h"
24+
#include "mlir/IR/MLIRContext.h"
25+
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/IR/Types.h"
27+
#include "mlir/IR/Value.h"
28+
#include "mlir/Support/LLVM.h"
29+
#include "shardy/common/logging.h"
30+
#include "shardy/dialect/mpmd/ir/dialect.h"
31+
#include "shardy/dialect/mpmd/ir/utils.h"
32+
#include "shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep
33+
#include "shardy/dialect/mpmd/transforms/common/utils.h"
34+
#include "shardy/dialect/sdy/ir/dialect.h"
35+
36+
namespace mlir::mpmd {
37+
38+
#define GEN_PASS_DEF_UNIQUIFYANDMERGERETURNSPASS
39+
#include "shardy/dialect/mpmd/transforms/common/passes.h.inc"
40+
41+
namespace {
42+
43+
using ValueToReturnIndices = llvm::MapVector<Value, SmallVector<int64_t>>;
44+
45+
void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
46+
ValueToReturnIndices& value_to_return_indices,
47+
OpBuilder& builder) {
48+
// We remove any entries that require no work, in order to avoid too many
49+
// checks.
50+
value_to_return_indices.remove_if([](const auto& it) {
51+
if (it.second.size() == 1) {
52+
Value v = it.first;
53+
return !isa<BlockArgument>(v);
54+
}
55+
return it.second.empty();
56+
});
57+
58+
builder.setInsertionPoint(return_op);
59+
SmallVector<Value> fragment_operands;
60+
fragment_operands.reserve(value_to_return_indices.size());
61+
SmallVector<Type> fragment_return_types;
62+
for (const auto& [value, return_indices] : value_to_return_indices) {
63+
fragment_operands.push_back(value);
64+
fragment_return_types.insert(fragment_return_types.end(),
65+
return_indices.size(),
66+
cast<MeshTensorType>(value.getType()));
67+
}
68+
69+
if (fragment_operands.empty()) {
70+
return;
71+
}
72+
73+
auto loc = return_op->getLoc();
74+
auto fragment_op = FragmentOp::create(
75+
builder, loc, fragment_return_types, fragment_operands,
76+
/*user_origin=*/ArrayAttr::get(builder.getContext(), {}),
77+
/*mesh_name=*/mesh_name, /*stage_id=*/IntegerAttr());
78+
Block& fragment_block = fragment_op.getRegion().emplaceBlock();
79+
80+
SmallVector<Value> returned_values;
81+
returned_values.reserve(fragment_return_types.size());
82+
// The index of the fragment result that we should use to replace the
83+
// function return op operand.
84+
int fragment_result_index = 0;
85+
sdy::MeshAttr mesh_attr = GetMeshOrFail(fragment_op, mesh_name);
86+
for (const auto& [value, return_indices] : value_to_return_indices) {
87+
// Add a single block argument for this value and return it as many times
88+
// as it's used.
89+
returned_values.insert(
90+
returned_values.end(), return_indices.size(),
91+
fragment_block.addArgument(
92+
GetGlobalTensorTypeFromMeshType(value, mesh_attr), value.getLoc()));
93+
94+
for (int64_t index : return_indices) {
95+
return_op->setOperand(index,
96+
fragment_op->getResult(fragment_result_index++));
97+
}
98+
}
99+
auto block_builder = OpBuilder::atBlockEnd(&fragment_block);
100+
ReturnOp::create(block_builder, loc, returned_values);
101+
102+
Operation* latest_producer = nullptr;
103+
for (Value v : fragment_operands) {
104+
if (Operation* op = v.getDefiningOp()) {
105+
if (!latest_producer || latest_producer->isBeforeInBlock(op)) {
106+
latest_producer = op;
107+
}
108+
}
109+
}
110+
111+
if (latest_producer) {
112+
if (auto producer_fragment = dyn_cast<FragmentOp>(latest_producer)) {
113+
if (producer_fragment.getMeshName() == mesh_name) {
114+
fragment_op->moveAfter(producer_fragment);
115+
IRRewriter rewriter(builder.getContext());
116+
MergeRegionOps(
117+
producer_fragment, fragment_op, rewriter,
118+
/*num_static_args=*/0, /*replace_producer_use_in_consumer_block=*/
119+
[](OpOperand&, Value) {
120+
SDY_CHECK(false) << "Fragment ops shouldn't have free variables";
121+
},
122+
GetFragmentOriginUnion(producer_fragment, fragment_op, rewriter),
123+
producer_fragment.getMeshNameAttr(),
124+
/*stage_id=*/producer_fragment.getStageIdAttr());
125+
}
126+
}
127+
}
128+
}
129+
130+
class UniquifyAndMergeReturnsPass
131+
: public impl::UniquifyAndMergeReturnsPassBase<
132+
UniquifyAndMergeReturnsPass> {
133+
using UniquifyAndMergeReturnsPassBase::UniquifyAndMergeReturnsPassBase;
134+
135+
private:
136+
void runOnFunc(func::FuncOp func_op) override {
137+
if (!IsMpmdFunction(func_op)) {
138+
// This is not the main function. Do nothing.
139+
return;
140+
}
141+
142+
Operation* return_op = func_op.getBody().front().getTerminator();
143+
// value_to_return_indices_per_mesh[mesh_name] = value_to_return_indices
144+
// where value_to_return_indices[v] contains a sequence of the indices in
145+
// return op where v is used.
146+
llvm::MapVector<StringRef, ValueToReturnIndices>
147+
value_to_return_indices_per_mesh;
148+
for (OpOperand& operand : return_op->getOpOperands()) {
149+
auto mesh_type = dyn_cast<MeshTensorType>(operand.get().getType());
150+
SDY_CHECK(mesh_type);
151+
StringRef mesh_name = mesh_type.getMeshName();
152+
value_to_return_indices_per_mesh[mesh_name][operand.get()].push_back(
153+
operand.getOperandNumber());
154+
}
155+
156+
OpBuilder builder(&getContext());
157+
for (auto& [mesh_name, value_to_return_indices] :
158+
value_to_return_indices_per_mesh) {
159+
CreateReturnFragmentForMesh(mesh_name, return_op, value_to_return_indices,
160+
builder);
161+
}
162+
}
163+
};
164+
165+
} // namespace
166+
} // namespace mlir::mpmd

shardy/dialect/mpmd/transforms/export/export_pipeline.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,7 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) {
8484
// Must be applied after the last -mpmd-fragment-dedup, as it may add
8585
// duplicated fragment results and after -canonicalize, as it may add
8686
// identity fragments, which would be canonicalized away.
87-
pm.addNestedPass<FuncOp>(createUniquifyFunctionInputsOutputsPass());
88-
89-
// The fragments created by the pass above maybe slowdown compilation (more
90-
// fragments to compile) and may cause performance regressions. Thus, we merge
91-
// them with other fragments.
92-
pm.addNestedPass<FuncOp>(createMergeInferredFragmentsPass());
87+
pm.addNestedPass<FuncOp>(createUniquifyAndMergeReturnsPass());
9388

9489
// Mark each fragment with the inputs and outputs which are offloaded to host
9590
// memory.

0 commit comments

Comments
 (0)