Skip to content

Commit 9e06b27

Browse files
[mpmd] Move pre rule-based scheduling/merging passes to import pipeline
Rule generation will occur after import, so further changes to fragments should not occur after import until rule-based scheduling/merging. This CL should be a no-op for non rule-based merge pass users. PiperOrigin-RevId: 803073873
1 parent b9bda19 commit 9e06b27

17 files changed

Lines changed: 344 additions & 245 deletions

shardy/dialect/mpmd/ir/fragment_execution_rules.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,23 @@ bool ParseFragmentInfo(llvm::cl::Option& opt, llvm::StringRef& arg,
125125
"Expected 'kKeepTransferred' or 'kDropTransferred' for "
126126
"'split_type'");
127127
}
128+
} else if (arg.consume_front("mesh_name=")) {
129+
if (!info.mesh_name.empty()) {
130+
return opt.error("'mesh_name' specified more than once");
131+
}
132+
if (!arg.consume_front("\"")) {
133+
return opt.error("Expected '\"' to start 'mesh_name'");
134+
}
135+
auto [mesh_name, rest] = arg.split('"');
136+
if (mesh_name == arg) {
137+
return opt.error("Expected '\"' to end 'mesh_name'");
138+
}
139+
info.mesh_name = mesh_name.str();
140+
arg = rest;
128141
} else {
129142
return opt.error(
130-
"Expected 'stage=', 'call_counter=', or "
131-
"'split_type=' after ','");
143+
"Expected 'stage=', 'call_counter=', 'split_type=', or "
144+
"'mesh_name=' after ','");
132145
}
133146
}
134147
if (!arg.consume_front(")")) {
@@ -157,7 +170,8 @@ FragmentInfo GetFragmentInfo(FragmentOp fragment) {
157170
std::optional<int64_t> call_counter = TryToFindCallCounter(fragment);
158171
std::vector<FragmentOrigin> origins = GetFragmentOrigins(fragment);
159172
std::optional<SplitFragmentType> split_type = GetSplitFragmentType(fragment);
160-
return FragmentInfo{origins, stage_id, call_counter, split_type};
173+
return FragmentInfo{origins, stage_id, call_counter, split_type,
174+
fragment.getMeshName().str()};
161175
}
162176

163177
void SetFragmentInfo(FragmentOp fragment, const FragmentInfo& metadata,
@@ -197,6 +211,9 @@ void SetFragmentInfo(FragmentOp fragment, const FragmentInfo& metadata,
197211
fragment->removeAttr(kSplitDropTransferredAttrName);
198212
fragment->removeAttr(kSplitKeepTransferredAttrName);
199213
}
214+
215+
fragment.setMeshName(
216+
StringAttr::get(rewriter.getContext(), metadata.mesh_name));
200217
}
201218

202219
} // namespace mlir::mpmd

shardy/dialect/mpmd/ir/fragment_execution_rules.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,12 @@ struct FragmentInfo {
110110
std::optional<int> stage_id;
111111
std::optional<int> call_counter;
112112
std::optional<SplitFragmentType> split_type;
113+
std::string mesh_name;
113114

114115
bool operator==(const FragmentInfo& other) const {
115116
return llvm::equal(origins, other.origins) && stage_id == other.stage_id &&
116-
call_counter == other.call_counter && split_type == other.split_type;
117+
call_counter == other.call_counter &&
118+
split_type == other.split_type && mesh_name == other.mesh_name;
117119
}
118120

119121
bool operator!=(const FragmentInfo& other) const { return !(*this == other); }
@@ -133,6 +135,7 @@ struct FragmentInfo {
133135
if (info.split_type.has_value()) {
134136
os << ",split_type=" << *info.split_type;
135137
}
138+
os << ",mesh_name=\"" << info.mesh_name << "\"";
136139
os << ")";
137140
return os;
138141
}
@@ -141,8 +144,8 @@ struct FragmentInfo {
141144
struct FragmentInfoMapInfo : public DenseMapInfo<FragmentInfo> {
142145
static unsigned getHashValue(const FragmentInfo& info) {
143146
return llvm::hash_combine(llvm::hash_combine_range(info.origins),
144-
info.stage_id, info.call_counter,
145-
info.split_type);
147+
info.stage_id, info.call_counter, info.split_type,
148+
info.mesh_name);
146149
}
147150
static bool isEqual(const FragmentInfo& lhs, const FragmentInfo& rhs) {
148151
return lhs == rhs;
@@ -152,14 +155,16 @@ struct FragmentInfoMapInfo : public DenseMapInfo<FragmentInfo> {
152155
return FragmentInfo{/*origins=*/{},
153156
/*stage_id=*/DenseMapInfo<int>::getEmptyKey(),
154157
/*call_counter=*/DenseMapInfo<int>::getEmptyKey(),
155-
/*split_type=*/std::nullopt};
158+
/*split_type=*/std::nullopt,
159+
/*mesh_name=*/""};
156160
}
157161

158162
static inline FragmentInfo getTombstoneKey() {
159163
return FragmentInfo{/*origins=*/{},
160164
/*stage_id=*/DenseMapInfo<int>::getTombstoneKey(),
161165
/*call_counter=*/DenseMapInfo<int>::getTombstoneKey(),
162-
/*split_type=*/SplitFragmentType::kDropTransferred};
166+
/*split_type=*/SplitFragmentType::kDropTransferred,
167+
/*mesh_name=*/"__tombstone__"};
163168
}
164169
};
165170

shardy/dialect/mpmd/ir/fragment_execution_rules_test.cc

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ FragmentOrigin MakeFragmentOrigin(const std::string& computation_name,
5959
}
6060

6161
FragmentInfo MakeFragmentInfo(
62-
const std::vector<FragmentOrigin>& origins,
62+
const std::vector<FragmentOrigin>& origins, const std::string& mesh_name,
6363
std::optional<int> stage_id = std::nullopt,
6464
std::optional<int> call_counter = std::nullopt,
6565
std::optional<SplitFragmentType> split_type = std::nullopt) {
66-
return {origins, stage_id, call_counter, split_type};
66+
return {origins, stage_id, call_counter, split_type, mesh_name};
6767
}
6868

6969
FragmentMergeRule MakeFragmentMergeRule(
@@ -139,6 +139,7 @@ TEST(GetFragmentInfoTest, GetFragmentInfo) {
139139
fragment_info,
140140
MakeFragmentInfo(
141141
{MakeFragmentOrigin("f1", 123), MakeFragmentOrigin("f2", 123)},
142+
/*mesh_name=*/"m1",
142143
/*stage_id=*/std::nullopt,
143144
/*call_counter=*/std::nullopt, /*split_type=*/std::nullopt));
144145
}
@@ -187,12 +188,14 @@ INSTANTIATE_TEST_SUITE_P(
187188
testing::Values(
188189
SetFragmentInfoTestParams{
189190
"WithStageAndCallCounter",
190-
MakeFragmentInfo({MakeFragmentOrigin("f3", 456)}, /*stage_id=*/1,
191-
/*call_counter=*/2, /*split_type=*/std::nullopt)},
191+
MakeFragmentInfo({MakeFragmentOrigin("f3", 456)},
192+
/*mesh_name=*/"m1",
193+
/*stage_id=*/1, /*call_counter=*/2,
194+
/*split_type=*/std::nullopt)},
192195
SetFragmentInfoTestParams{
193196
"WithWeightGradient",
194197
MakeFragmentInfo(
195-
{MakeFragmentOrigin("f4", 789)},
198+
{MakeFragmentOrigin("f4", 789)}, /*mesh_name=*/"m1",
196199
/*stage_id=*/std::nullopt,
197200
/*call_counter=*/std::nullopt,
198201
/*split_type=*/SplitFragmentType::kDropTransferred)}),
@@ -224,6 +227,7 @@ TEST(SetFragmentInfoTest, RemovesSplitDropTransferred) {
224227

225228
IRRewriter rewriter(&context);
226229
FragmentInfo info = MakeFragmentInfo({MakeFragmentOrigin("f1", 0)},
230+
/*mesh_name=*/"m1",
227231
/*stage_id=*/std::nullopt,
228232
/*call_counter=*/std::nullopt,
229233
/*split_type=*/std::nullopt);
@@ -258,66 +262,68 @@ INSTANTIATE_TEST_SUITE_P(
258262
"NoSplitType",
259263
MakeFragmentInfo({MakeFragmentOrigin("f1", 123),
260264
MakeFragmentOrigin("f2", 456)},
261-
/*stage_id=*/1, /*call_counter=*/2,
262-
/*split_type=*/std::nullopt),
265+
/*mesh_name=*/"m1", /*stage_id=*/1,
266+
/*call_counter=*/2, /*split_type=*/std::nullopt),
263267
"FragmentInfo(origins=[\"f1\"(123),\"f2\"(456)],stage=1,call_"
264-
"counter=2)"},
268+
"counter=2,mesh_name=\"m1\")"},
265269
PrintFragmentInfoTestParams{
266270
"WithSplitTypeDropTransferred",
267271
MakeFragmentInfo(
268-
{MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1,
269-
/*call_counter=*/2,
272+
{MakeFragmentOrigin("f1", 123)}, /*mesh_name=*/"m1",
273+
/*stage_id=*/1, /*call_counter=*/2,
270274
/*split_type=*/SplitFragmentType::kDropTransferred),
271275
"FragmentInfo(origins=[\"f1\"(123)],stage=1,call_counter=2,"
272-
"split_type=kDropTransferred)"},
276+
"split_type=kDropTransferred,mesh_name=\"m1\")"},
273277
PrintFragmentInfoTestParams{
274278
"WithSplitTypeKeepTransferred",
275279
MakeFragmentInfo(
276-
{MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1,
277-
/*call_counter=*/2,
280+
{MakeFragmentOrigin("f1", 123)}, /*mesh_name=*/"m1",
281+
/*stage_id=*/1, /*call_counter=*/2,
278282
/*split_type=*/SplitFragmentType::kKeepTransferred),
279283
"FragmentInfo(origins=[\"f1\"(123)],stage=1,call_counter=2,"
280-
"split_type=kKeepTransferred)"},
284+
"split_type=kKeepTransferred,mesh_name=\"m1\")"},
281285
PrintFragmentInfoTestParams{
282286
"OnlyRequiredFields",
283-
MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}),
284-
"FragmentInfo(origins=[\"f1\"(123)])"}),
287+
MakeFragmentInfo({MakeFragmentOrigin("f1", 123)},
288+
/*mesh_name=*/"m1"),
289+
"FragmentInfo(origins=[\"f1\"(123)],mesh_name=\"m1\")"}),
285290
[](const testing::TestParamInfo<PrintFragmentInfoTest::ParamType>& info) {
286291
return info.param.test_name;
287292
});
288293

289294
TEST(FragmentMergeRule, PrintFragmentMergeRule) {
290295
FragmentMergeRule rule = MakeFragmentMergeRule(
291-
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1),
292-
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*stage_id=*/1)},
296+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*mesh_name=*/"m1",
297+
/*stage_id=*/1),
298+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*mesh_name=*/"m1",
299+
/*stage_id=*/1)},
293300
MakeFragmentInfo(
294301
{MakeFragmentOrigin("f1", 123), MakeFragmentOrigin("f2", 456)},
295-
/*stage_id=*/1, /*call_counter=*/std::nullopt,
302+
/*mesh_name=*/"m1", /*stage_id=*/1, /*call_counter=*/std::nullopt,
296303
/*split_type=*/std::nullopt));
297304
std::string str;
298305
llvm::raw_string_ostream os(str);
299306
os << rule;
300-
EXPECT_THAT(str, Eq("FragmentMergeRule(sources=["
301-
"FragmentInfo(origins=[\"f1\"(123)],stage=1),"
302-
"FragmentInfo(origins=[\"f2\"(456)],stage=1)],"
303-
"target=FragmentInfo(origins=["
304-
"\"f1\"(123),\"f2\"(456)],stage=1))"));
307+
EXPECT_THAT(
308+
str, Eq("FragmentMergeRule(sources=["
309+
"FragmentInfo(origins=[\"f1\"(123)],stage=1,mesh_name=\"m1\"),"
310+
"FragmentInfo(origins=[\"f2\"(456)],stage=1,mesh_name=\"m1\")],"
311+
"target=FragmentInfo(origins=["
312+
"\"f1\"(123),\"f2\"(456)],stage=1,mesh_name=\"m1\"))"));
305313
}
306314

307315
TEST(FragmentMergeRuleParser, ParseValidRule) {
308316
FragmentMergeRule expected_rule = MakeFragmentMergeRule(
309-
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1,
310-
/*call_counter=*/std::nullopt,
317+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*mesh_name=*/"m1",
318+
/*stage_id=*/1, /*call_counter=*/std::nullopt,
311319
/*split_type=*/std::nullopt),
312-
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)},
313-
/*stage_id=*/1,
314-
/*call_counter=*/std::nullopt,
320+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*mesh_name=*/"m1",
321+
/*stage_id=*/1, /*call_counter=*/std::nullopt,
315322
/*split_type=*/SplitFragmentType::kDropTransferred)},
316323
MakeFragmentInfo(
317324
{MakeFragmentOrigin("f1", 123), MakeFragmentOrigin("f2", 456)},
318-
/*stage_id=*/1,
319-
/*call_counter=*/std::nullopt,
320-
/*split_type=*/std::nullopt));
325+
/*mesh_name=*/"m1", /*stage_id=*/1,
326+
/*call_counter=*/std::nullopt, /*split_type=*/std::nullopt));
321327
// We first construct the rule and print it to a string. Then we parse that
322328
// string to ensure that the printed form of a rule is directly compatible
323329
// with the format the parser expects.
@@ -370,24 +376,27 @@ INSTANTIATE_TEST_SUITE_P(
370376

371377
TEST(FragmentScheduleRule, PrintFragmentScheduleRule) {
372378
FragmentScheduleRule rule = MakeFragmentScheduleRule(
373-
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1),
374-
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*stage_id=*/2)});
379+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*mesh_name=*/"m1",
380+
/*stage_id=*/1),
381+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*mesh_name=*/"m1",
382+
/*stage_id=*/2)});
375383
std::string str;
376384
llvm::raw_string_ostream os(str);
377385
os << rule;
378-
EXPECT_THAT(str, Eq("FragmentScheduleRule(ordered_fragments=["
379-
"FragmentInfo(origins=[\"f1\"(123)],stage=1)->"
380-
"FragmentInfo(origins=[\"f2\"(456)],stage=2)])"));
386+
EXPECT_THAT(
387+
str,
388+
Eq("FragmentScheduleRule(ordered_fragments=["
389+
"FragmentInfo(origins=[\"f1\"(123)],stage=1,mesh_name=\"m1\")->"
390+
"FragmentInfo(origins=[\"f2\"(456)],stage=2,mesh_name=\"m1\")])"));
381391
}
382392

383393
TEST(FragmentScheduleRuleParser, ParseValidRule) {
384394
FragmentScheduleRule expected_rule = MakeFragmentScheduleRule(
385-
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1,
386-
/*call_counter=*/std::nullopt,
395+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*mesh_name=*/"m1",
396+
/*stage_id=*/1, /*call_counter=*/std::nullopt,
387397
/*split_type=*/std::nullopt),
388-
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)},
389-
/*stage_id=*/1,
390-
/*call_counter=*/std::nullopt,
398+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*mesh_name=*/"m1",
399+
/*stage_id=*/1, /*call_counter=*/std::nullopt,
391400
/*split_type=*/SplitFragmentType::kDropTransferred)});
392401
// We first construct the rule and print it to a string. Then we parse that
393402
// string to ensure that the printed form of a rule is directly compatible
@@ -439,18 +448,18 @@ INSTANTIATE_TEST_SUITE_P(
439448
});
440449

441450
TEST(FragmentInfoMapInfoTest, IsEqual) {
442-
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
443-
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
444-
FragmentInfo info3 = MakeFragmentInfo({MakeFragmentOrigin("f2", 456)});
451+
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
452+
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
453+
FragmentInfo info3 = MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, "m1");
445454

446455
EXPECT_TRUE(FragmentInfoMapInfo::isEqual(info1, info2));
447456
EXPECT_FALSE(FragmentInfoMapInfo::isEqual(info1, info3));
448457
}
449458

450459
TEST(FragmentInfoMapInfoTest, GetHashValue) {
451-
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
452-
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
453-
FragmentInfo info3 = MakeFragmentInfo({MakeFragmentOrigin("f2", 456)});
460+
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
461+
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
462+
FragmentInfo info3 = MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, "m1");
454463

455464
EXPECT_EQ(FragmentInfoMapInfo::getHashValue(info1),
456465
FragmentInfoMapInfo::getHashValue(info2));
@@ -462,7 +471,7 @@ TEST(FragmentInfoMapInfoTest, GetHashValue) {
462471
TEST(FragmentInfoMapInfoTest, SpecialKeys) {
463472
FragmentInfo emptyKey = FragmentInfoMapInfo::getEmptyKey();
464473
FragmentInfo tombstoneKey = FragmentInfoMapInfo::getTombstoneKey();
465-
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
474+
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
466475

467476
EXPECT_FALSE(FragmentInfoMapInfo::isEqual(emptyKey, info1));
468477
EXPECT_FALSE(FragmentInfoMapInfo::isEqual(tombstoneKey, info1));
@@ -472,8 +481,8 @@ TEST(FragmentInfoMapInfoTest, SpecialKeys) {
472481
TEST(FragmentInfoMapInfoTest, DenseMapIntegration) {
473482
llvm::DenseMap<FragmentInfo, int, FragmentInfoMapInfo> map;
474483

475-
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
476-
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f2", 456)});
484+
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
485+
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, "m1");
477486

478487
map[info1] = 1;
479488
map[info2] = 2;
@@ -482,7 +491,8 @@ TEST(FragmentInfoMapInfoTest, DenseMapIntegration) {
482491
EXPECT_EQ(map[info1], 1);
483492
EXPECT_EQ(map[info2], 2);
484493

485-
FragmentInfo info1_copy = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
494+
FragmentInfo info1_copy =
495+
MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, "m1");
486496
EXPECT_TRUE(map.contains(info1_copy));
487497
EXPECT_EQ(map[info1_copy], 1);
488498

shardy/dialect/mpmd/transforms/common/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ cc_library(
3939
"merge_transfers.cc",
4040
"remove_transfer_cycles.cc",
4141
"rule_based_merge.cc",
42+
"scheduler_preprocess.cc",
4243
"split_bwd_fragments.cc",
4344
"uniquify_function_inputs_outputs.cc",
4445
"unroll_for_loops.cc",
4546
],
4647
hdrs = [
4748
"merge_fragments.h",
4849
"passes.h",
50+
"scheduler_preprocess.h",
4951
],
5052
deps = [
5153
":distributed_function_pass",

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,21 @@ def UniquifyFunctionInputsOutputsPass :
486486

487487
let dependentDialects = ["mlir::mpmd::MpmdDialect"];
488488
}
489+
490+
def SchedulingUnitVerifierPass :
491+
PassBase<"mpmd-scheduling-units-verifier", "DistributedFunctionPass"> {
492+
let summary = "Verifies if the program contains the required scheduling units.";
493+
}
494+
495+
// TODO: b/378099938 - Remove this pass once we have a better way to handle
496+
// transfers while merging fragments. We need this now because having a transfer
497+
// in between two fragments prevents the merge pass from merging them.
498+
def MoveTransfersToProducerPass :
499+
PassBase<"mpmd-move-transfers-to-producer", "DistributedFunctionPass"> {
500+
let summary = "Moves transfers next to their producers.";
501+
let description = [{
502+
Moves transfers next to their producers: if the operand is a block argument,
503+
move the transfer to the beginning of the block, otherwise move it after the
504+
defining op.
505+
}];
506+
}

0 commit comments

Comments
 (0)