@@ -20,6 +20,7 @@ limitations under the License.
2020#include < vector>
2121
2222#include " llvm/ADT/STLExtras.h"
23+ #include " llvm/Support/CommandLine.h"
2324#include " llvm/Support/raw_ostream.h"
2425#include " mlir/IR/BuiltinOps.h"
2526#include " mlir/IR/MLIRContext.h"
@@ -69,6 +70,48 @@ FragmentMergeRule MakeFragmentMergeRule(
6970 return {sources, target};
7071}
7172
73+ FragmentScheduleRule MakeFragmentScheduleRule (
74+ const std::vector<FragmentInfo>& ordered_fragments) {
75+ return {ordered_fragments};
76+ }
77+
78+ // LLVM's command line classes (OptionCategory, opt) store StringRef arguments
79+ // directly without copying the underlying string data. When these objects are
80+ // created with temporary string literals in test functions, the backing strings
81+ // go out of scope after the test completes, leaving dangling pointers in the
82+ // static GlobalParser->RegisteredOptionCategories.
83+ //
84+ // The functions below use static storage to ensure string literals have static
85+ // storage duration, avoiding the need for manual cleanup. The parser helper
86+ // functions encapsulate opt/parser creation and provide a clean interface for
87+ // tests without exposing StringRef lifetime concerns.
88+
89+ llvm::cl::OptionCategory& getTestOptionCategory () {
90+ static llvm::cl::OptionCategory category (" Test Options" );
91+ return category;
92+ }
93+
94+ bool parseFragmentMergeRule (llvm::StringRef rule_str, FragmentMergeRule& rule) {
95+ static llvm::cl::opt<FragmentMergeRule> rule_opt (
96+ " fragment-merge-rule" ,
97+ llvm::cl::desc (" Fragment merge rule for testing parser functionality" ),
98+ llvm::cl::cat (getTestOptionCategory ()));
99+ static llvm::cl::parser<FragmentMergeRule> parser (rule_opt);
100+
101+ return parser.parse (rule_opt, " test-rule" , rule_str, rule);
102+ }
103+
104+ bool parseFragmentScheduleRule (llvm::StringRef rule_str,
105+ FragmentScheduleRule& rule) {
106+ static llvm::cl::opt<FragmentScheduleRule> rule_opt (
107+ " fragment-schedule-rule" ,
108+ llvm::cl::desc (" Fragment schedule rule for testing parser functionality" ),
109+ llvm::cl::cat (getTestOptionCategory ()));
110+ static llvm::cl::parser<FragmentScheduleRule> parser (rule_opt);
111+
112+ return parser.parse (rule_opt, " test-rule" , rule_str, rule);
113+ }
114+
72115TEST (GetFragmentInfoTest, GetFragmentInfo) {
73116 const std::string kProgram = R"mlir(
74117 !mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
@@ -259,6 +302,144 @@ TEST(FragmentMergeRule, PrintFragmentMergeRule) {
259302 " \" f1\" (123),\" f2\" (456)],stage=1,is_weight_gradient=false))" ));
260303}
261304
305+ TEST (FragmentMergeRuleParser, ParseValidRule) {
306+ FragmentMergeRule expected_rule = MakeFragmentMergeRule (
307+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )},
308+ /* stage_id=*/ 1 ,
309+ /* call_counter=*/ std::nullopt ,
310+ /* is_weight_gradient=*/ false ),
311+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )},
312+ /* stage_id=*/ 1 ,
313+ /* call_counter=*/ std::nullopt ,
314+ /* is_weight_gradient=*/ true )},
315+ MakeFragmentInfo (
316+ {MakeFragmentOrigin (" f1" , 123 ), MakeFragmentOrigin (" f2" , 456 )},
317+ /* stage_id=*/ 1 ,
318+ /* call_counter=*/ std::nullopt ,
319+ /* is_weight_gradient=*/ false ));
320+ // We first construct the rule and print it to a string. Then we parse that
321+ // string to ensure that the printed form of a rule is directly compatible
322+ // with the format the parser expects.
323+ std::string rule_str;
324+ llvm::raw_string_ostream os (rule_str);
325+ os << expected_rule;
326+
327+ FragmentMergeRule rule;
328+ bool result = parseFragmentMergeRule (rule_str, rule);
329+
330+ EXPECT_FALSE (result);
331+
332+ ASSERT_EQ (rule.sources .size (), 2 );
333+ ExpectFragmentInfoEq (rule.sources [0 ], expected_rule.sources [0 ]);
334+ ExpectFragmentInfoEq (rule.sources [1 ], expected_rule.sources [1 ]);
335+ ExpectFragmentInfoEq (rule.target , expected_rule.target );
336+ }
337+
338+ struct InvalidRuleTestParams {
339+ std::string test_name;
340+ std::string invalid_rule_str;
341+ };
342+
343+ class FragmentMergeRuleParserInvalidSyntaxTest
344+ : public ::testing::TestWithParam<InvalidRuleTestParams> {};
345+
346+ TEST_P (FragmentMergeRuleParserInvalidSyntaxTest, ParseInvalidRule) {
347+ const auto & params = GetParam ();
348+ FragmentMergeRule rule;
349+ bool result = parseFragmentMergeRule (params.invalid_rule_str , rule);
350+
351+ EXPECT_TRUE (result);
352+ }
353+
354+ INSTANTIATE_TEST_SUITE_P (
355+ FragmentMergeRuleParser, FragmentMergeRuleParserInvalidSyntaxTest,
356+ testing::Values (
357+ InvalidRuleTestParams{" MissingPrefix" ,
358+ " sources=[FragmentInfo(origins=[\" f1\" (123)])]" },
359+ InvalidRuleTestParams{
360+ " MissingSources" ,
361+ " FragmentMergeRule(target=FragmentInfo(origins=[\" f1\" (123)]))" },
362+ InvalidRuleTestParams{" MissingTarget" ,
363+ " FragmentMergeRule(sources=[FragmentInfo(origins="
364+ " [\" f1\" (123)])])" }),
365+ [](const testing::TestParamInfo<
366+ FragmentMergeRuleParserInvalidSyntaxTest::ParamType>& info) {
367+ return info.param .test_name ;
368+ });
369+
370+ TEST (FragmentScheduleRule, PrintFragmentScheduleRule) {
371+ FragmentScheduleRule rule = MakeFragmentScheduleRule (
372+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* stage_id=*/ 1 ),
373+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* stage_id=*/ 1 )});
374+ std::string str;
375+ llvm::raw_string_ostream os (str);
376+ os << rule;
377+ EXPECT_THAT (str, Eq (" FragmentScheduleRule(ordered_fragments=["
378+ " FragmentInfo(origins=[\" f1\" (123)],stage=1,is_weight_"
379+ " gradient=false)->"
380+ " FragmentInfo(origins=[\" f2\" (456)],stage=1,is_weight_"
381+ " gradient=false)])" ));
382+ }
383+
384+ TEST (FragmentScheduleRuleParser, ParseValidRule) {
385+ FragmentScheduleRule expected_rule = MakeFragmentScheduleRule (
386+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )},
387+ /* stage_id=*/ 1 ,
388+ /* call_counter=*/ std::nullopt ,
389+ /* is_weight_gradient=*/ false ),
390+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )},
391+ /* stage_id=*/ 1 ,
392+ /* call_counter=*/ std::nullopt ,
393+ /* is_weight_gradient=*/ true )});
394+ // We first construct the rule and print it to a string. Then we parse that
395+ // string to ensure that the printed form of a rule is directly compatible
396+ // with the format the parser expects.
397+ std::string rule_str;
398+ llvm::raw_string_ostream os (rule_str);
399+ os << expected_rule;
400+
401+ FragmentScheduleRule rule;
402+ bool result = parseFragmentScheduleRule (rule_str, rule);
403+
404+ EXPECT_FALSE (result);
405+
406+ ASSERT_EQ (rule.ordered_fragments .size (), 2 );
407+ ExpectFragmentInfoEq (rule.ordered_fragments [0 ],
408+ expected_rule.ordered_fragments [0 ]);
409+ ExpectFragmentInfoEq (rule.ordered_fragments [1 ],
410+ expected_rule.ordered_fragments [1 ]);
411+ }
412+
413+ class FragmentScheduleRuleParserInvalidSyntaxTest
414+ : public ::testing::TestWithParam<InvalidRuleTestParams> {};
415+
416+ TEST_P (FragmentScheduleRuleParserInvalidSyntaxTest, ParseInvalidRule) {
417+ const auto & params = GetParam ();
418+ FragmentScheduleRule rule;
419+ bool result = parseFragmentScheduleRule (params.invalid_rule_str , rule);
420+
421+ EXPECT_TRUE (result);
422+ }
423+
424+ INSTANTIATE_TEST_SUITE_P (
425+ FragmentScheduleRuleParser, FragmentScheduleRuleParserInvalidSyntaxTest,
426+ testing::Values (
427+ InvalidRuleTestParams{" MissingPrefix" ,
428+ " [FragmentInfo(origins=[\" f1\" (123)])->"
429+ " FragmentInfo(origins=[\" f2\" (456)])])" },
430+ InvalidRuleTestParams{
431+ " MissingArrow" ,
432+ " FragmentScheduleRule(ordered_fragments=[FragmentInfo(origins=["
433+ " \" f1\" (123)]) FragmentInfo(origins=[\" f2\" (456)])])" },
434+ InvalidRuleTestParams{
435+ " MissingClosingBrackets" ,
436+ " FragmentScheduleRule(ordered_fragments=[FragmentInfo(origins=["
437+ " \" f1\" (123)])->FragmentInfo(origins=[\" f2\" (456)])" }),
438+ [](const testing::TestParamInfo<
439+ FragmentScheduleRuleParserInvalidSyntaxTest::ParamType>& info) {
440+ return info.param .test_name ;
441+ });
442+
262443TEST (FragmentInfoMapInfoTest, IsEqual) {
263444 FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
264445 FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
0 commit comments