@@ -59,11 +59,11 @@ FragmentOrigin MakeFragmentOrigin(const std::string& computation_name,
5959}
6060
6161FragmentInfo MakeFragmentInfo (
62- const std::vector<FragmentOrigin>& origins,
62+ const std::vector<FragmentOrigin>& origins, const std::string& mesh_name,
6363 std::optional<int > stage_id = std::nullopt ,
6464 std::optional<int > call_counter = std::nullopt ,
6565 std::optional<SplitFragmentType> split_type = std::nullopt ) {
66- return {origins, stage_id, call_counter, split_type};
66+ return {origins, stage_id, call_counter, split_type, mesh_name };
6767}
6868
6969FragmentMergeRule MakeFragmentMergeRule (
@@ -139,6 +139,7 @@ TEST(GetFragmentInfoTest, GetFragmentInfo) {
139139 fragment_info,
140140 MakeFragmentInfo (
141141 {MakeFragmentOrigin (" f1" , 123 ), MakeFragmentOrigin (" f2" , 123 )},
142+ /* mesh_name=*/ " m1" ,
142143 /* stage_id=*/ std::nullopt ,
143144 /* call_counter=*/ std::nullopt , /* split_type=*/ std::nullopt ));
144145}
@@ -187,12 +188,14 @@ INSTANTIATE_TEST_SUITE_P(
187188 testing::Values (
188189 SetFragmentInfoTestParams{
189190 " WithStageAndCallCounter" ,
190- MakeFragmentInfo ({MakeFragmentOrigin (" f3" , 456 )}, /* stage_id=*/ 1 ,
191- /* call_counter=*/ 2 , /* split_type=*/ std::nullopt )},
191+ MakeFragmentInfo ({MakeFragmentOrigin (" f3" , 456 )},
192+ /* mesh_name=*/ " m1" ,
193+ /* stage_id=*/ 1 , /* call_counter=*/ 2 ,
194+ /* split_type=*/ std::nullopt )},
192195 SetFragmentInfoTestParams{
193196 " WithWeightGradient" ,
194197 MakeFragmentInfo (
195- {MakeFragmentOrigin (" f4" , 789 )},
198+ {MakeFragmentOrigin (" f4" , 789 )}, /* mesh_name= */ " m1 " ,
196199 /* stage_id=*/ std::nullopt ,
197200 /* call_counter=*/ std::nullopt ,
198201 /* split_type=*/ SplitFragmentType::kDropTransferred )}),
@@ -224,6 +227,7 @@ TEST(SetFragmentInfoTest, RemovesSplitDropTransferred) {
224227
225228 IRRewriter rewriter (&context);
226229 FragmentInfo info = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 0 )},
230+ /* mesh_name=*/ " m1" ,
227231 /* stage_id=*/ std::nullopt ,
228232 /* call_counter=*/ std::nullopt ,
229233 /* split_type=*/ std::nullopt );
@@ -258,66 +262,68 @@ INSTANTIATE_TEST_SUITE_P(
258262 " NoSplitType" ,
259263 MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 ),
260264 MakeFragmentOrigin (" f2" , 456 )},
261- /* stage_id =*/ 1 , /* call_counter =*/ 2 ,
262- /* split_type=*/ std::nullopt ),
265+ /* mesh_name =*/ " m1 " , /* stage_id =*/ 1 ,
266+ /* call_counter= */ 2 , /* split_type=*/ std::nullopt ),
263267 " FragmentInfo(origins=[\" f1\" (123),\" f2\" (456)],stage=1,call_"
264- " counter=2)" },
268+ " counter=2,mesh_name= \" m1 \" )" },
265269 PrintFragmentInfoTestParams{
266270 " WithSplitTypeDropTransferred" ,
267271 MakeFragmentInfo (
268- {MakeFragmentOrigin (" f1" , 123 )}, /* stage_id =*/ 1 ,
269- /* call_counter=*/ 2 ,
272+ {MakeFragmentOrigin (" f1" , 123 )}, /* mesh_name =*/ " m1 " ,
273+ /* stage_id= */ 1 , /* call_counter=*/ 2 ,
270274 /* split_type=*/ SplitFragmentType::kDropTransferred ),
271275 " FragmentInfo(origins=[\" f1\" (123)],stage=1,call_counter=2,"
272- " split_type=kDropTransferred)" },
276+ " split_type=kDropTransferred,mesh_name= \" m1 \" )" },
273277 PrintFragmentInfoTestParams{
274278 " WithSplitTypeKeepTransferred" ,
275279 MakeFragmentInfo (
276- {MakeFragmentOrigin (" f1" , 123 )}, /* stage_id =*/ 1 ,
277- /* call_counter=*/ 2 ,
280+ {MakeFragmentOrigin (" f1" , 123 )}, /* mesh_name =*/ " m1 " ,
281+ /* stage_id= */ 1 , /* call_counter=*/ 2 ,
278282 /* split_type=*/ SplitFragmentType::kKeepTransferred ),
279283 " FragmentInfo(origins=[\" f1\" (123)],stage=1,call_counter=2,"
280- " split_type=kKeepTransferred)" },
284+ " split_type=kKeepTransferred,mesh_name= \" m1 \" )" },
281285 PrintFragmentInfoTestParams{
282286 " OnlyRequiredFields" ,
283- MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}),
284- " FragmentInfo(origins=[\" f1\" (123)])" }),
287+ MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )},
288+ /* mesh_name=*/ " m1" ),
289+ " FragmentInfo(origins=[\" f1\" (123)],mesh_name=\" m1\" )" }),
285290 [](const testing::TestParamInfo<PrintFragmentInfoTest::ParamType>& info) {
286291 return info.param .test_name ;
287292 });
288293
289294TEST (FragmentMergeRule, PrintFragmentMergeRule) {
290295 FragmentMergeRule rule = MakeFragmentMergeRule (
291- {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* stage_id=*/ 1 ),
292- MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* stage_id=*/ 1 )},
296+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* mesh_name=*/ " m1" ,
297+ /* stage_id=*/ 1 ),
298+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* mesh_name=*/ " m1" ,
299+ /* stage_id=*/ 1 )},
293300 MakeFragmentInfo (
294301 {MakeFragmentOrigin (" f1" , 123 ), MakeFragmentOrigin (" f2" , 456 )},
295- /* stage_id=*/ 1 , /* call_counter=*/ std::nullopt ,
302+ /* mesh_name= */ " m1 " , /* stage_id=*/ 1 , /* call_counter=*/ std::nullopt ,
296303 /* split_type=*/ std::nullopt ));
297304 std::string str;
298305 llvm::raw_string_ostream os (str);
299306 os << rule;
300- EXPECT_THAT (str, Eq (" FragmentMergeRule(sources=["
301- " FragmentInfo(origins=[\" f1\" (123)],stage=1),"
302- " FragmentInfo(origins=[\" f2\" (456)],stage=1)],"
303- " target=FragmentInfo(origins=["
304- " \" f1\" (123),\" f2\" (456)],stage=1))" ));
307+ EXPECT_THAT (
308+ str, Eq (" FragmentMergeRule(sources=["
309+ " FragmentInfo(origins=[\" f1\" (123)],stage=1,mesh_name=\" m1\" ),"
310+ " FragmentInfo(origins=[\" f2\" (456)],stage=1,mesh_name=\" m1\" )],"
311+ " target=FragmentInfo(origins=["
312+ " \" f1\" (123),\" f2\" (456)],stage=1,mesh_name=\" m1\" ))" ));
305313}
306314
307315TEST (FragmentMergeRuleParser, ParseValidRule) {
308316 FragmentMergeRule expected_rule = MakeFragmentMergeRule (
309- {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* stage_id =*/ 1 ,
310- /* call_counter=*/ std::nullopt ,
317+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* mesh_name =*/ " m1 " ,
318+ /* stage_id= */ 1 , /* call_counter=*/ std::nullopt ,
311319 /* split_type=*/ std::nullopt ),
312- MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )},
313- /* stage_id=*/ 1 ,
314- /* call_counter=*/ std::nullopt ,
320+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* mesh_name=*/ " m1" ,
321+ /* stage_id=*/ 1 , /* call_counter=*/ std::nullopt ,
315322 /* split_type=*/ SplitFragmentType::kDropTransferred )},
316323 MakeFragmentInfo (
317324 {MakeFragmentOrigin (" f1" , 123 ), MakeFragmentOrigin (" f2" , 456 )},
318- /* stage_id=*/ 1 ,
319- /* call_counter=*/ std::nullopt ,
320- /* split_type=*/ std::nullopt ));
325+ /* mesh_name=*/ " m1" , /* stage_id=*/ 1 ,
326+ /* call_counter=*/ std::nullopt , /* split_type=*/ std::nullopt ));
321327 // We first construct the rule and print it to a string. Then we parse that
322328 // string to ensure that the printed form of a rule is directly compatible
323329 // with the format the parser expects.
@@ -370,24 +376,27 @@ INSTANTIATE_TEST_SUITE_P(
370376
371377TEST (FragmentScheduleRule, PrintFragmentScheduleRule) {
372378 FragmentScheduleRule rule = MakeFragmentScheduleRule (
373- {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* stage_id=*/ 1 ),
374- MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* stage_id=*/ 2 )});
379+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* mesh_name=*/ " m1" ,
380+ /* stage_id=*/ 1 ),
381+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* mesh_name=*/ " m1" ,
382+ /* stage_id=*/ 2 )});
375383 std::string str;
376384 llvm::raw_string_ostream os (str);
377385 os << rule;
378- EXPECT_THAT (str, Eq (" FragmentScheduleRule(ordered_fragments=["
379- " FragmentInfo(origins=[\" f1\" (123)],stage=1)->"
380- " FragmentInfo(origins=[\" f2\" (456)],stage=2)])" ));
386+ EXPECT_THAT (
387+ str,
388+ Eq (" FragmentScheduleRule(ordered_fragments=["
389+ " FragmentInfo(origins=[\" f1\" (123)],stage=1,mesh_name=\" m1\" )->"
390+ " FragmentInfo(origins=[\" f2\" (456)],stage=2,mesh_name=\" m1\" )])" ));
381391}
382392
383393TEST (FragmentScheduleRuleParser, ParseValidRule) {
384394 FragmentScheduleRule expected_rule = MakeFragmentScheduleRule (
385- {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* stage_id =*/ 1 ,
386- /* call_counter=*/ std::nullopt ,
395+ {MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, /* mesh_name =*/ " m1 " ,
396+ /* stage_id= */ 1 , /* call_counter=*/ std::nullopt ,
387397 /* split_type=*/ std::nullopt ),
388- MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )},
389- /* stage_id=*/ 1 ,
390- /* call_counter=*/ std::nullopt ,
398+ MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, /* mesh_name=*/ " m1" ,
399+ /* stage_id=*/ 1 , /* call_counter=*/ std::nullopt ,
391400 /* split_type=*/ SplitFragmentType::kDropTransferred )});
392401 // We first construct the rule and print it to a string. Then we parse that
393402 // string to ensure that the printed form of a rule is directly compatible
@@ -439,18 +448,18 @@ INSTANTIATE_TEST_SUITE_P(
439448 });
440449
441450TEST (FragmentInfoMapInfoTest, IsEqual) {
442- FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
443- FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
444- FragmentInfo info3 = MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )});
451+ FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1 " );
452+ FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1 " );
453+ FragmentInfo info3 = MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, " m1 " );
445454
446455 EXPECT_TRUE (FragmentInfoMapInfo::isEqual (info1, info2));
447456 EXPECT_FALSE (FragmentInfoMapInfo::isEqual (info1, info3));
448457}
449458
450459TEST (FragmentInfoMapInfoTest, GetHashValue) {
451- FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
452- FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
453- FragmentInfo info3 = MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )});
460+ FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1 " );
461+ FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1 " );
462+ FragmentInfo info3 = MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, " m1 " );
454463
455464 EXPECT_EQ (FragmentInfoMapInfo::getHashValue (info1),
456465 FragmentInfoMapInfo::getHashValue (info2));
@@ -462,7 +471,7 @@ TEST(FragmentInfoMapInfoTest, GetHashValue) {
462471TEST (FragmentInfoMapInfoTest, SpecialKeys) {
463472 FragmentInfo emptyKey = FragmentInfoMapInfo::getEmptyKey ();
464473 FragmentInfo tombstoneKey = FragmentInfoMapInfo::getTombstoneKey ();
465- FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
474+ FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1 " );
466475
467476 EXPECT_FALSE (FragmentInfoMapInfo::isEqual (emptyKey, info1));
468477 EXPECT_FALSE (FragmentInfoMapInfo::isEqual (tombstoneKey, info1));
@@ -472,8 +481,8 @@ TEST(FragmentInfoMapInfoTest, SpecialKeys) {
472481TEST (FragmentInfoMapInfoTest, DenseMapIntegration) {
473482 llvm::DenseMap<FragmentInfo, int , FragmentInfoMapInfo> map;
474483
475- FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
476- FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )});
484+ FragmentInfo info1 = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1 " );
485+ FragmentInfo info2 = MakeFragmentInfo ({MakeFragmentOrigin (" f2" , 456 )}, " m1 " );
477486
478487 map[info1] = 1 ;
479488 map[info2] = 2 ;
@@ -482,7 +491,8 @@ TEST(FragmentInfoMapInfoTest, DenseMapIntegration) {
482491 EXPECT_EQ (map[info1], 1 );
483492 EXPECT_EQ (map[info2], 2 );
484493
485- FragmentInfo info1_copy = MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )});
494+ FragmentInfo info1_copy =
495+ MakeFragmentInfo ({MakeFragmentOrigin (" f1" , 123 )}, " m1" );
486496 EXPECT_TRUE (map.contains (info1_copy));
487497 EXPECT_EQ (map[info1_copy], 1 );
488498
0 commit comments