Skip to content

Commit 7654f1c

Browse files
[mpmd] Add FragmentScheduleRule
A FragmentScheduleRule contains a sequence of fragments, the order of which determines the relative order of the fragments in the pipeline schedule. This will be used by the custom scheduling API to allow users to create their own custom schedules. PiperOrigin-RevId: 800824524
1 parent ad97da3 commit 7654f1c

3 files changed

Lines changed: 256 additions & 1 deletion

File tree

shardy/dialect/mpmd/ir/fragment_execution_rules.cc

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ bool ParseFragmentOrigin(llvm::cl::Option& opt, llvm::StringRef& arg,
6565
if (!arg.consume_front("(")) {
6666
return false;
6767
}
68-
if (!arg.consumeInteger(10, origin.transpose_count)) {
68+
if (arg.consumeInteger(10, origin.transpose_count)) {
6969
return opt.error("Expected a transpose count");
7070
}
7171
if (!arg.consume_front(")")) {
@@ -187,6 +187,7 @@ void SetFragmentInfo(FragmentOp fragment, const FragmentInfo& metadata,
187187
namespace llvm::cl {
188188

189189
using ::mlir::mpmd::FragmentMergeRule;
190+
using ::mlir::mpmd::FragmentScheduleRule;
190191

191192
template class basic_parser<FragmentMergeRule>;
192193

@@ -228,4 +229,39 @@ void parser<FragmentMergeRule>::printOptionDiff(const Option& opt,
228229

229230
void parser<FragmentMergeRule>::anchor() {}
230231

232+
template class basic_parser<FragmentScheduleRule>;
233+
234+
// Parses a fragment schedule rule string of the form
235+
// "FragmentScheduleRule(ordered_fragments=[<fragment1>-><fragment2>...])"
236+
// <fragment>s are FragmentInfo strings.
237+
bool parser<FragmentScheduleRule>::parse(Option& opt, StringRef, StringRef arg,
238+
FragmentScheduleRule& value) {
239+
if (!arg.consume_front(FragmentScheduleRule::kFragmentScheduleRulePrefix)) {
240+
return opt.error("Expected '" +
241+
FragmentScheduleRule::kFragmentScheduleRulePrefix + "'");
242+
}
243+
while (!arg.starts_with("]")) {
244+
if (mlir::mpmd::ParseFragmentInfo(opt, arg,
245+
value.ordered_fragments.emplace_back())) {
246+
return true; // opt.error was called inside ParseFragmentInfo
247+
}
248+
if (!arg.consume_front("->")) {
249+
break;
250+
}
251+
}
252+
if (!arg.consume_front("])")) {
253+
return opt.error("Expected '])'");
254+
}
255+
return false;
256+
}
257+
258+
void parser<FragmentScheduleRule>::printOptionDiff(
259+
const Option& opt, const FragmentScheduleRule& value,
260+
const OptVal& defaultValue, size_t globalWidth) const {
261+
printOptionName(opt, globalWidth);
262+
outs() << "= " << value << "\n";
263+
}
264+
265+
void parser<FragmentScheduleRule>::anchor() {}
266+
231267
} // namespace llvm::cl

shardy/dialect/mpmd/ir/fragment_execution_rules.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "llvm/ADT/DenseMapInfo.h"
2727
#include "llvm/ADT/Hashing.h"
2828
#include "llvm/ADT/STLExtras.h"
29+
#include "llvm/ADT/StringRef.h"
2930
#include "llvm/Support/CommandLine.h"
3031
#include "llvm/Support/raw_ostream.h"
3132
#include "mlir/IR/PatternMatch.h"
@@ -142,6 +143,27 @@ struct FragmentMergeRule {
142143

143144
using FragmentMergeRules = std::vector<FragmentMergeRule>;
144145

146+
// Describes a rule for scheduling fragments. A rule is defined by an ordered
147+
// sequence of fragments. This ordering dictates the execution order of the
148+
// fragments on a given mesh.
149+
struct FragmentScheduleRule {
150+
// The sequence of fragments to be scheduled. The order of fragments in this
151+
// vector defines their execution order.
152+
std::vector<FragmentInfo> ordered_fragments;
153+
154+
static constexpr llvm::StringRef kFragmentScheduleRulePrefix =
155+
"FragmentScheduleRule(ordered_fragments=[";
156+
157+
friend llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
158+
const FragmentScheduleRule& rule) {
159+
os << kFragmentScheduleRulePrefix;
160+
llvm::interleave(rule.ordered_fragments, os, "->");
161+
return os << "])";
162+
}
163+
};
164+
165+
using FragmentScheduleRules = std::vector<FragmentScheduleRule>;
166+
145167
// Returns the fragment info of a fragment op.
146168
FragmentInfo GetFragmentInfo(FragmentOp fragment);
147169

@@ -169,6 +191,22 @@ class parser<mlir::mpmd::FragmentMergeRule>
169191
void anchor() override;
170192
};
171193

194+
extern template class basic_parser<mlir::mpmd::FragmentScheduleRule>;
195+
196+
template <>
197+
class parser<mlir::mpmd::FragmentScheduleRule>
198+
: public basic_parser<mlir::mpmd::FragmentScheduleRule> {
199+
public:
200+
parser(Option& opt) : basic_parser(opt) {}
201+
bool parse(Option& opt, StringRef argName, StringRef arg,
202+
mlir::mpmd::FragmentScheduleRule& value);
203+
StringRef getValueName() const override { return "fragment-schedule-rule"; }
204+
void printOptionDiff(const Option& opt,
205+
const mlir::mpmd::FragmentScheduleRule& value,
206+
const OptVal& defaultValue, size_t globalWidth) const;
207+
void anchor() override;
208+
};
209+
172210
} // namespace llvm::cl
173211

174212
#endif // SHARDY_DIALECT_MPMD_IR_FRAGMENT_EXECUTION_RULES_H_

shardy/dialect/mpmd/ir/fragment_execution_rules_test.cc

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include <vector>
2121

2222
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/Support/CommandLine.h"
2324
#include "llvm/Support/raw_ostream.h"
2425
#include "mlir/IR/BuiltinOps.h"
2526
#include "mlir/IR/MLIRContext.h"
@@ -69,6 +70,48 @@ FragmentMergeRule MakeFragmentMergeRule(
6970
return {sources, target};
7071
}
7172

73+
FragmentScheduleRule MakeFragmentScheduleRule(
74+
const std::vector<FragmentInfo>& ordered_fragments) {
75+
return {ordered_fragments};
76+
}
77+
78+
// LLVM's command line classes (OptionCategory, opt) store StringRef arguments
79+
// directly without copying the underlying string data. When these objects are
80+
// created with temporary string literals in test functions, the backing strings
81+
// go out of scope after the test completes, leaving dangling pointers in the
82+
// static GlobalParser->RegisteredOptionCategories.
83+
//
84+
// The functions below use static storage to ensure string literals have static
85+
// storage duration, avoiding the need for manual cleanup. The parser helper
86+
// functions encapsulate opt/parser creation and provide a clean interface for
87+
// tests without exposing StringRef lifetime concerns.
88+
89+
llvm::cl::OptionCategory& getTestOptionCategory() {
90+
static llvm::cl::OptionCategory category("Test Options");
91+
return category;
92+
}
93+
94+
bool parseFragmentMergeRule(llvm::StringRef rule_str, FragmentMergeRule& rule) {
95+
static llvm::cl::opt<FragmentMergeRule> rule_opt(
96+
"fragment-merge-rule",
97+
llvm::cl::desc("Fragment merge rule for testing parser functionality"),
98+
llvm::cl::cat(getTestOptionCategory()));
99+
static llvm::cl::parser<FragmentMergeRule> parser(rule_opt);
100+
101+
return parser.parse(rule_opt, "test-rule", rule_str, rule);
102+
}
103+
104+
bool parseFragmentScheduleRule(llvm::StringRef rule_str,
105+
FragmentScheduleRule& rule) {
106+
static llvm::cl::opt<FragmentScheduleRule> rule_opt(
107+
"fragment-schedule-rule",
108+
llvm::cl::desc("Fragment schedule rule for testing parser functionality"),
109+
llvm::cl::cat(getTestOptionCategory()));
110+
static llvm::cl::parser<FragmentScheduleRule> parser(rule_opt);
111+
112+
return parser.parse(rule_opt, "test-rule", rule_str, rule);
113+
}
114+
72115
TEST(GetFragmentInfoTest, GetFragmentInfo) {
73116
const std::string kProgram = R"mlir(
74117
!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
@@ -259,6 +302,144 @@ TEST(FragmentMergeRule, PrintFragmentMergeRule) {
259302
"\"f1\"(123),\"f2\"(456)],stage=1,is_weight_gradient=false))"));
260303
}
261304

305+
TEST(FragmentMergeRuleParser, ParseValidRule) {
306+
FragmentMergeRule expected_rule = MakeFragmentMergeRule(
307+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)},
308+
/*stage_id=*/1,
309+
/*call_counter=*/std::nullopt,
310+
/*is_weight_gradient=*/false),
311+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)},
312+
/*stage_id=*/1,
313+
/*call_counter=*/std::nullopt,
314+
/*is_weight_gradient=*/true)},
315+
MakeFragmentInfo(
316+
{MakeFragmentOrigin("f1", 123), MakeFragmentOrigin("f2", 456)},
317+
/*stage_id=*/1,
318+
/*call_counter=*/std::nullopt,
319+
/*is_weight_gradient=*/false));
320+
// We first construct the rule and print it to a string. Then we parse that
321+
// string to ensure that the printed form of a rule is directly compatible
322+
// with the format the parser expects.
323+
std::string rule_str;
324+
llvm::raw_string_ostream os(rule_str);
325+
os << expected_rule;
326+
327+
FragmentMergeRule rule;
328+
bool result = parseFragmentMergeRule(rule_str, rule);
329+
330+
EXPECT_FALSE(result);
331+
332+
ASSERT_EQ(rule.sources.size(), 2);
333+
ExpectFragmentInfoEq(rule.sources[0], expected_rule.sources[0]);
334+
ExpectFragmentInfoEq(rule.sources[1], expected_rule.sources[1]);
335+
ExpectFragmentInfoEq(rule.target, expected_rule.target);
336+
}
337+
338+
struct InvalidRuleTestParams {
339+
std::string test_name;
340+
std::string invalid_rule_str;
341+
};
342+
343+
class FragmentMergeRuleParserInvalidSyntaxTest
344+
: public ::testing::TestWithParam<InvalidRuleTestParams> {};
345+
346+
TEST_P(FragmentMergeRuleParserInvalidSyntaxTest, ParseInvalidRule) {
347+
const auto& params = GetParam();
348+
FragmentMergeRule rule;
349+
bool result = parseFragmentMergeRule(params.invalid_rule_str, rule);
350+
351+
EXPECT_TRUE(result);
352+
}
353+
354+
INSTANTIATE_TEST_SUITE_P(
355+
FragmentMergeRuleParser, FragmentMergeRuleParserInvalidSyntaxTest,
356+
testing::Values(
357+
InvalidRuleTestParams{"MissingPrefix",
358+
"sources=[FragmentInfo(origins=[\"f1\"(123)])]"},
359+
InvalidRuleTestParams{
360+
"MissingSources",
361+
"FragmentMergeRule(target=FragmentInfo(origins=[\"f1\"(123)]))"},
362+
InvalidRuleTestParams{"MissingTarget",
363+
"FragmentMergeRule(sources=[FragmentInfo(origins="
364+
"[\"f1\"(123)])])"}),
365+
[](const testing::TestParamInfo<
366+
FragmentMergeRuleParserInvalidSyntaxTest::ParamType>& info) {
367+
return info.param.test_name;
368+
});
369+
370+
TEST(FragmentScheduleRule, PrintFragmentScheduleRule) {
371+
FragmentScheduleRule rule = MakeFragmentScheduleRule(
372+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)}, /*stage_id=*/1),
373+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)}, /*stage_id=*/1)});
374+
std::string str;
375+
llvm::raw_string_ostream os(str);
376+
os << rule;
377+
EXPECT_THAT(str, Eq("FragmentScheduleRule(ordered_fragments=["
378+
"FragmentInfo(origins=[\"f1\"(123)],stage=1,is_weight_"
379+
"gradient=false)->"
380+
"FragmentInfo(origins=[\"f2\"(456)],stage=1,is_weight_"
381+
"gradient=false)])"));
382+
}
383+
384+
TEST(FragmentScheduleRuleParser, ParseValidRule) {
385+
FragmentScheduleRule expected_rule = MakeFragmentScheduleRule(
386+
{MakeFragmentInfo({MakeFragmentOrigin("f1", 123)},
387+
/*stage_id=*/1,
388+
/*call_counter=*/std::nullopt,
389+
/*is_weight_gradient=*/false),
390+
MakeFragmentInfo({MakeFragmentOrigin("f2", 456)},
391+
/*stage_id=*/1,
392+
/*call_counter=*/std::nullopt,
393+
/*is_weight_gradient=*/true)});
394+
// We first construct the rule and print it to a string. Then we parse that
395+
// string to ensure that the printed form of a rule is directly compatible
396+
// with the format the parser expects.
397+
std::string rule_str;
398+
llvm::raw_string_ostream os(rule_str);
399+
os << expected_rule;
400+
401+
FragmentScheduleRule rule;
402+
bool result = parseFragmentScheduleRule(rule_str, rule);
403+
404+
EXPECT_FALSE(result);
405+
406+
ASSERT_EQ(rule.ordered_fragments.size(), 2);
407+
ExpectFragmentInfoEq(rule.ordered_fragments[0],
408+
expected_rule.ordered_fragments[0]);
409+
ExpectFragmentInfoEq(rule.ordered_fragments[1],
410+
expected_rule.ordered_fragments[1]);
411+
}
412+
413+
class FragmentScheduleRuleParserInvalidSyntaxTest
414+
: public ::testing::TestWithParam<InvalidRuleTestParams> {};
415+
416+
TEST_P(FragmentScheduleRuleParserInvalidSyntaxTest, ParseInvalidRule) {
417+
const auto& params = GetParam();
418+
FragmentScheduleRule rule;
419+
bool result = parseFragmentScheduleRule(params.invalid_rule_str, rule);
420+
421+
EXPECT_TRUE(result);
422+
}
423+
424+
INSTANTIATE_TEST_SUITE_P(
425+
FragmentScheduleRuleParser, FragmentScheduleRuleParserInvalidSyntaxTest,
426+
testing::Values(
427+
InvalidRuleTestParams{"MissingPrefix",
428+
"[FragmentInfo(origins=[\"f1\"(123)])->"
429+
"FragmentInfo(origins=[\"f2\"(456)])])"},
430+
InvalidRuleTestParams{
431+
"MissingArrow",
432+
"FragmentScheduleRule(ordered_fragments=[FragmentInfo(origins=["
433+
"\"f1\"(123)]) FragmentInfo(origins=[\"f2\"(456)])])"},
434+
InvalidRuleTestParams{
435+
"MissingClosingBrackets",
436+
"FragmentScheduleRule(ordered_fragments=[FragmentInfo(origins=["
437+
"\"f1\"(123)])->FragmentInfo(origins=[\"f2\"(456)])"}),
438+
[](const testing::TestParamInfo<
439+
FragmentScheduleRuleParserInvalidSyntaxTest::ParamType>& info) {
440+
return info.param.test_name;
441+
});
442+
262443
TEST(FragmentInfoMapInfoTest, IsEqual) {
263444
FragmentInfo info1 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});
264445
FragmentInfo info2 = MakeFragmentInfo({MakeFragmentOrigin("f1", 123)});

0 commit comments

Comments
 (0)