Skip to content

Commit e31984f

Browse files
[mpmd] Move pre rule-based scheduling/merging passes to import pipeline
Rule generation will occur after import, so fragments should not be changed any further between the end of import and rule-based scheduling/merging. This CL should be a no-op for users who do not use the rule-based merge pass. PiperOrigin-RevId: 803073873
1 parent edb927d commit e31984f

12 files changed

Lines changed: 250 additions & 183 deletions

File tree

shardy/dialect/mpmd/transforms/common/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ cc_library(
3939
"merge_transfers.cc",
4040
"remove_transfer_cycles.cc",
4141
"rule_based_merge.cc",
42+
"scheduler_preprocess.cc",
4243
"split_bwd_fragments.cc",
4344
"uniquify_function_inputs_outputs.cc",
4445
"unroll_for_loops.cc",
4546
],
4647
hdrs = [
4748
"merge_fragments.h",
4849
"passes.h",
50+
"scheduler_preprocess.h",
4951
],
5052
deps = [
5153
":distributed_function_pass",

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,21 @@ def UniquifyFunctionInputsOutputsPass :
486486

487487
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
488488
}
489+
490+
// Declares the `mpmd-scheduling-units-verifier` pass. Only a summary is
// declared here (no description or options).
// NOTE(review): the exact conditions checked are defined by the C++ pass
// implementation, not by this declaration -- consult it before relying on them.
def SchedulingUnitVerifierPass :
491+
PassBase<"mpmd-scheduling-units-verifier", "DistributedFunctionPass"> {
492+
let summary = "Verifies if the program contains the required scheduling units.";
493+
}
494+
495+
// Declares the `mpmd-move-transfers-to-producer` pass. It exists to unblock
// fragment merging and is intended to be temporary (see the TODO below).
// TODO: b/378099938 - Remove this pass once we have a better way to handle
496+
// transfers while merging fragments. We need this now because having a transfer
497+
// in between two fragments prevents the merge pass from merging them.
498+
def MoveTransfersToProducerPass :
499+
PassBase<"mpmd-move-transfers-to-producer", "DistributedFunctionPass"> {
500+
let summary = "Moves transfers next to their producers.";
501+
let description = [{
502+
Moves transfers next to their producers: if the operand is a block argument,
503+
move the transfer to the beginning of the block, otherwise move it after the
504+
defining op.
505+
}];
506+
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/* Copyright 2025 The MPMD Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "shardy/dialect/mpmd/transforms/common/scheduler_preprocess.h"
17+
18+
#include <algorithm>
19+
#include <cstdint>
20+
21+
#include "mlir/Dialect/Func/IR/FuncOps.h"
22+
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/IR/Value.h"
24+
#include "mlir/Pass/PassManager.h"
25+
#include "mlir/Support/LLVM.h"
26+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27+
#include "mlir/Transforms/Passes.h"
28+
#include "shardy/common/logging.h"
29+
#include "shardy/dialect/mpmd/ir/dialect.h"
30+
#include "shardy/dialect/mpmd/ir/utils.h"
31+
#include "shardy/dialect/mpmd/transforms/common/passes.h"
32+
#include "shardy/dialect/mpmd/transforms/optimize/utils.h"
33+
34+
namespace mlir::mpmd {
35+
36+
#define GEN_PASS_DEF_SCHEDULINGUNITVERIFIERPASS
37+
#define GEN_PASS_DEF_MOVETRANSFERSTOPRODUCERPASS
38+
#include "shardy/dialect/mpmd/transforms/common/passes.h.inc"
39+
40+
namespace {
41+
42+
using ::mlir::func::FuncOp;
43+
44+
// Returns the number of microbatches in the program.
45+
// TODO(jupvfranco): This code assumes that microbatching is zero- or one-
46+
// based. Can we generalize this?
47+
uint32_t GetNumMicrobatches(FuncOp func_op) {
48+
uint32_t max_call_counter = 0;
49+
bool is_zero_based = false;
50+
func_op.walk([&max_call_counter, &is_zero_based](FragmentOp fragment) {
51+
if (auto call_counter = TryToFindCallCounter(fragment)) {
52+
if (*call_counter == 0) {
53+
is_zero_based = true;
54+
}
55+
max_call_counter = std::max(max_call_counter, *call_counter);
56+
}
57+
});
58+
return max_call_counter + (is_zero_based ? 1 : 0);
59+
}
60+
61+
class SchedulingUnitVerifierPass
62+
: public impl::SchedulingUnitVerifierPassBase<SchedulingUnitVerifierPass> {
63+
using SchedulingUnitVerifierPassBase::SchedulingUnitVerifierPassBase;
64+
65+
private:
66+
void runOnFunc(FuncOp func_op) override {
67+
if (!IsMpmdFunction(func_op)) {
68+
return;
69+
}
70+
71+
const uint32_t num_microbatches = GetNumMicrobatches(func_op);
72+
if (num_microbatches == 0) {
73+
SDY_LOG(WARNING)
74+
<< "Function is not microbatched and therefore cannot be "
75+
"rescheduled.";
76+
// We exit instead of emitting an error so that this won't affect init
77+
// functions that are typically not microbatched.
78+
return;
79+
}
80+
81+
// Check if every mesh has `num_microbatches` scheduling units, half of them
82+
// forward and the other half backward.
83+
// TODO(jupvfranco): This works for the simple schedules we support now, but
84+
// we need to revisit this logic.
85+
for (NamedMeshAttr mesh : GetSchedulableMeshes(func_op)) {
86+
int count_fwd = 0, count_bwd = 0;
87+
for (Operation& op : func_op.getOps()) {
88+
auto fragment = dyn_cast<FragmentOp>(&op);
89+
if (!fragment || !IsSchedulingUnit(fragment) ||
90+
fragment.getMeshName() != mesh.getName()) {
91+
continue;
92+
}
93+
if (*TryToFindSingleTransposeCount(fragment) == 0) {
94+
count_fwd++;
95+
} else {
96+
count_bwd++;
97+
}
98+
}
99+
if (count_fwd != num_microbatches) {
100+
func_op.emitWarning("Number of forward scheduling units in mesh ")
101+
<< mesh.getName() << " does not match expected number for "
102+
<< num_microbatches << " microbatches. Got " << count_fwd << ".";
103+
}
104+
105+
if (count_bwd != num_microbatches) {
106+
func_op.emitWarning("Number of backward scheduling units in mesh ")
107+
<< mesh.getName() << " does not match expected number for "
108+
<< num_microbatches << " microbatches. Got " << count_bwd << ".";
109+
}
110+
}
111+
}
112+
};
113+
114+
class MoveTransfersToProducerPass
115+
: public impl::MoveTransfersToProducerPassBase<
116+
MoveTransfersToProducerPass> {
117+
using MoveTransfersToProducerPassBase::MoveTransfersToProducerPassBase;
118+
119+
private:
120+
void runOnFunc(FuncOp func) override {
121+
IRRewriter rewriter(func.getContext());
122+
func.walk([&](TransferOp transfer) {
123+
if (auto arg = dyn_cast<BlockArgument>(transfer.getOperand())) {
124+
rewriter.moveOpBefore(transfer, arg.getOwner(),
125+
arg.getOwner()->begin());
126+
} else {
127+
rewriter.moveOpAfter(transfer, transfer.getOperand().getDefiningOp());
128+
}
129+
});
130+
}
131+
};
132+
133+
} // namespace
134+
135+
void AddSchedulingPreprocessingPasses(OpPassManager& pm,
136+
bool split_bwd_fragments,
137+
bool verify_schedule_units) {
138+
// The following seems like a good thing to always do, to keep the module
139+
// more tidy and merged, even if we are not going to actually do any
140+
// scheduling.
141+
// Move transfers to right after their producers. Without this pass, if we
142+
// have a producer fragment followed by transfers, then a consumer fragment,
143+
// even if the operands of the transfers are from a different producer
144+
// fragment, we are not able to merge the producer and consumer fragments.
145+
// This pass moves the transfers to right after the producer, which allows
146+
// the merge pass to do its job.
147+
pm.addNestedPass<FuncOp>(createMoveTransfersToProducerPass());
148+
pm.addNestedPass<FuncOp>(
149+
createMergeUserDefinedFragmentsIntoSchedulingUnitsPass());
150+
if (verify_schedule_units) {
151+
pm.addNestedPass<FuncOp>(createSchedulingUnitVerifierPass());
152+
}
153+
154+
// TODO(dvytin): Run split_bwd_fragments independently of the schedule.
155+
//
156+
// Furthermore, we now do the split after verification, which ensures that
157+
// the generic verification code we have still works. But we should consider
158+
// defining schedule-specific verification conditions (and even passes to
159+
// prepare the module for a given schedule.)
160+
// TODO(dvytin): Investigate how to define schedule-specific verification.
161+
if (split_bwd_fragments) {
162+
pm.addNestedPass<FuncOp>(createSplitBwdFragmentsPass());
163+
// TODO(jupvfranco): Do we really need canonicalizations here? Tests seem to
164+
// fail without it.
165+
pm.addPass(createCanonicalizerPass(
166+
GreedyRewriteConfig().setRegionSimplificationLevel(
167+
GreedySimplifyRegionLevel::Disabled)));
168+
pm.addNestedPass<FuncOp>(createFragmentDcePass());
169+
}
170+
}
171+
172+
} // namespace mlir::mpmd
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright 2025 The MPMD Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef SHARDY_DIALECT_MPMD_TRANSFORMS_COMMON_SCHEDULER_PREPROCESS_H_
17+
#define SHARDY_DIALECT_MPMD_TRANSFORMS_COMMON_SCHEDULER_PREPROCESS_H_
18+
19+
#include "mlir/Pass/PassManager.h"
20+
21+
namespace mlir::mpmd {
22+
23+
// Adds all passes needed for pipeline scheduling preprocessing. This includes
24+
// moving transfers next to their producers, merging fragments into scheduling
25+
// units and, optionally, verifying the resulting scheduling units.
26+
//
27+
// When `split_bwd_fragments` is true, then we split backward fragments into
28+
// a fragment whose results are transferred, and one whose results are not.
29+
// This is so that we can execute the transfers earlier (e.g. as per Near-Zero
30+
// Bubble Pipeline).
//
// When `verify_schedule_units` is true, the scheduling-unit verifier pass is
// added right after the merge into scheduling units.
31+
void AddSchedulingPreprocessingPasses(mlir::OpPassManager& pm,
32+
bool split_bwd_fragments,
33+
bool verify_schedule_units);
34+
35+
} // namespace mlir::mpmd
36+
37+
#endif // SHARDY_DIALECT_MPMD_TRANSFORMS_COMMON_SCHEDULER_PREPROCESS_H_

shardy/dialect/mpmd/transforms/import/import_pipeline.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "mlir/Transforms/Passes.h"
2525
#include "shardy/dialect/mpmd/transforms/common/merge_fragments.h"
2626
#include "shardy/dialect/mpmd/transforms/common/passes.h"
27+
#include "shardy/dialect/mpmd/transforms/common/scheduler_preprocess.h"
2728
#include "shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.h"
2829
#include "shardy/dialect/mpmd/transforms/import/mesh_assignment_map.h"
2930
#include "shardy/dialect/mpmd/transforms/import/passes.h"
@@ -143,6 +144,20 @@ void addImportPipeline(OpPassManager& pm, ImportOptions options) {
143144
// Thus, we don't apply canonicalization again.
144145
pm.addNestedPass<FuncOp>(createFragmentDedupPass());
145146
pm.addNestedPass<FuncOp>(createFragmentDcePass());
147+
148+
// Apply optimization passes that modify fragments so fragments are stable
149+
// before rule-based merging/scheduling in the partition pipeline.
150+
// Apply as many optimizations as possible before inlining.
151+
pm.addNestedPass<FuncOp>(createRemoveTransferCyclesPass());
152+
AddCallInliningRelatedPasses(pm);
153+
// Merge any inferred fragments with user-defined fragments that could not be
154+
// merged before because of CallOps.
155+
if (!options.mergeAfterScheduling) {
156+
pm.addNestedPass<FuncOp>(createMergeInferredFragmentsPass());
157+
}
158+
// Merge fragments into scheduling units.
159+
AddSchedulingPreprocessingPasses(pm, options.splitBwdFragments,
160+
options.verifyScheduleUnits);
146161
}
147162

148163
namespace {

shardy/dialect/mpmd/transforms/import/passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ struct ImportOptions {
6262
InferMeshOptions inferMeshOptions;
6363
// Enable heterogeneous meshes.
6464
bool enableHeterogeneousMeshes = false;
65+
// Whether to split backward fragments.
66+
bool splitBwdFragments = false;
67+
// Whether to verify if merging created the right number of scheduling units.
68+
bool verifyScheduleUnits = false;
6569
};
6670

6771
// Adds the standard set of passes to import an MPMD program with a fixed mesh

shardy/dialect/mpmd/transforms/optimize/optimize_pipeline.cc

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,6 @@ namespace mlir::mpmd {
3131
using ::mlir::func::FuncOp;
3232

3333
void addOptimizePipeline(OpPassManager& pm, OptimizeOptions options) {
34-
// Apply as many optimizations as possible before inlining.
35-
pm.addNestedPass<FuncOp>(createRemoveTransferCyclesPass());
36-
37-
// TODO(jupvfranco): consider moving inlining to import.
38-
AddCallInliningRelatedPasses(pm);
39-
// Merge any inferred fragments with user-defined fragments that could not be
40-
// merged before because of CallOps.
41-
if (!options.mergeAfterScheduling) {
42-
pm.addNestedPass<FuncOp>(createMergeInferredFragmentsPass());
43-
}
44-
4534
// Merge fragments according to the user-specified rules. Do this before other
4635
// merge passes since those modify the origins of fragments, invalidating the
4736
// rules.
@@ -50,10 +39,7 @@ void addOptimizePipeline(OpPassManager& pm, OptimizeOptions options) {
5039
RuleBasedMergePassOptions{std::move(options.fragmentMergeRules)}));
5140
}
5241

53-
// Adds all pipeline scheduling related passes.
54-
// Merge fragments into scheduling units.
55-
AddSchedulingPreprocessingPasses(pm, options.splitBwdFragments,
56-
options.verifyScheduleUnits);
42+
// Adds pipeline scheduling pass.
5743
AddSchedulingPass(pm, options.pipelineSchedule);
5844

5945
// The remat passes will run after inlining the call ops and scheduling.

shardy/dialect/mpmd/transforms/optimize/passes.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ struct OptimizeOptions {
4444
SmallVector<FragmentMergeRule> fragmentMergeRules;
4545
// Whether to merge inferred fragments only after scheduling.
4646
bool mergeAfterScheduling = false;
47-
// Whether to split backward fragments.
48-
bool splitBwdFragments = false;
49-
// Whether to verify if merging created the right number of scheduling units.
50-
bool verifyScheduleUnits = false;
5147
// Whether to identify matching forward and backward fragments and clone the
5248
// forward fragment immediately.
5349
bool applyFragmentRemat = false;

shardy/dialect/mpmd/transforms/optimize/passes.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,3 @@ def PipelineSchedulerPass :
6767
"as follows: `builtin:<schedule-as-string>`.">
6868
];
6969
}
70-
71-
def SchedulingUnitVerifierPass :
72-
PassBase<"mpmd-scheduling-units-verifier", "DistributedFunctionPass"> {
73-
let summary = "Verifies if the program contains the required scheduling units.";
74-
}
75-
76-
// TODO: b/378099938 - Remove this pass once we have a better way to handle
77-
// transfers while merging fragments. We need this now because having a transfer
78-
// in between two fragments prevents the merge pass from merging them.
79-
def MoveTransfersToProducerPass :
80-
PassBase<"mpmd-move-transfers-to-producer", "DistributedFunctionPass"> {
81-
let summary = "Moves transfers next to their producers.";
82-
let description = [{
83-
Moves transfers next to their producers: if the operand is a block argument,
84-
move the transfer to the beginning of the block, otherwise move it after the
85-
defining op.
86-
}];
87-
}

0 commit comments

Comments
 (0)