@@ -48,6 +48,7 @@ class Collective:
4848 pg_id : int
4949 op_id : int
5050 profiling_name : str
51+ state : str
5152 time_created_ns : int
5253 time_discovered_started_ns : int
5354 time_discovered_completed_ns : int
@@ -416,12 +417,14 @@ def get_correct_seq_id(collective):
416417 max_completed_collective_seq_id = - 1
417418 max_enqueued_collective_seq_id = - 1
418419 local_pg_map = dict ()
420+ rank_id = None
419421 for c in collectives :
420422 rank_id = c .file_id
421- pg_status = self .pg_status [rank_id ][str (c .pg_id )]
422423 logger .debug (
423- f"rank_id: { rank_id } , c.pg_id: { c .pg_id } , c.file_id: { c .file_id } , c.collective_seq_id: { c .collective_seq_id } , process_group: { process_group } "
424+ f"rank_id: { rank_id } , c.pg_id: { c .pg_id } , c.file_id: { c .file_id } , c.collective_seq_id: { c .collective_seq_id } , process_group: { process_group } ,"
425+ f"c.state: { c .state } "
424426 )
427+ pg_status = self .pg_status [rank_id ][str (c .pg_id )]
425428 local_pg_map [rank_id ] = c .pg_id
426429 if (
427430 pg_status ['last_completed_collective' ]
@@ -441,6 +444,8 @@ def get_correct_seq_id(collective):
441444 # Ranks holding entries earlier than max_completed_collective_seq_id -> ranks failing to complete expected collectives
442445 rank_counts = defaultdict (list )
443446 for c in collectives :
447+ if c .state != 'scheduled' :
448+ continue
444449 rank_counts ['appeared' ].append (c .file_id )
445450 if get_correct_seq_id (c ) <= max_completed_collective_seq_id :
446451 rank_counts ['mismatched' ].append (c .file_id )
@@ -526,6 +531,7 @@ def pair_send_recv_operations():
526531 missing_ranks = set (global_ranks ) - set (unique_ranks )
527532 missing_ranks = missing_ranks | set (map (int , mismatched_rank_counts .keys ()))
528533
534+ correct_unique_ranks = set (unique_ranks ) - missing_ranks
529535 logger .debug (f"missing_ranks: { missing_ranks } " )
530536 process_group_str = process_group
531537
@@ -544,7 +550,7 @@ def pair_send_recv_operations():
544550 (size_str , 15 , '' ),
545551 (dtype , 10 , '' ),
546552 (total_unique_ranks , 10 , 'd' ),
547- (',' .join (map (str , unique_ranks )), 40 , '' ),
553+ (',' .join (map (str , correct_unique_ranks )), 40 , '' ),
548554 (',' .join (map (str , sorted (missing_ranks ))), 40 , '' ),
549555 ]
550556
@@ -827,7 +833,7 @@ def extract_collectives(data: Dict, file_id: str) -> List[Collective]:
827833 """
828834 collectives = []
829835 for entry in data ['entries' ]:
830- if 'collective_seq_id' in entry :
836+ if 'collective_seq_id' in entry and entry [ 'state' ] == 'scheduled' :
831837 collective = Collective (
832838 file_id = file_id ,
833839 collective_seq_id = entry ['collective_seq_id' ],
@@ -843,6 +849,7 @@ def extract_collectives(data: Dict, file_id: str) -> List[Collective]:
843849 'time_discovered_completed_ns' , entry ['time_created_ns' ]
844850 ),
845851 process_group = entry ['process_group' ],
852+ state = entry ['state' ],
846853 input_sizes = entry ['input_sizes' ],
847854 output_sizes = entry ['output_sizes' ],
848855 input_dtypes = entry ['input_dtypes' ],
0 commit comments