Skip to content

Commit 6bdcd77

Browse files
committed
Exclude complete entries at the moment
1 parent 6062163 commit 6bdcd77

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class Collective:
4848
pg_id: int
4949
op_id: int
5050
profiling_name: str
51+
state: str
5152
time_created_ns: int
5253
time_discovered_started_ns: int
5354
time_discovered_completed_ns: int
@@ -416,12 +417,14 @@ def get_correct_seq_id(collective):
416417
max_completed_collective_seq_id = -1
417418
max_enqueued_collective_seq_id = -1
418419
local_pg_map = dict()
420+
rank_id = None
419421
for c in collectives:
420422
rank_id = c.file_id
421-
pg_status = self.pg_status[rank_id][str(c.pg_id)]
422423
logger.debug(
423-
f"rank_id: {rank_id}, c.pg_id: {c.pg_id}, c.file_id: {c.file_id}, c.collective_seq_id: {c.collective_seq_id}, process_group: {process_group}"
424+
f"rank_id: {rank_id}, c.pg_id: {c.pg_id}, c.file_id: {c.file_id}, c.collective_seq_id: {c.collective_seq_id}, process_group: {process_group},"
425+
f"c.state: {c.state}"
424426
)
427+
pg_status = self.pg_status[rank_id][str(c.pg_id)]
425428
local_pg_map[rank_id] = c.pg_id
426429
if (
427430
pg_status['last_completed_collective']
@@ -441,6 +444,8 @@ def get_correct_seq_id(collective):
441444
# Ranks holding entries earlier than max_completed_collective_seq_id -> ranks failing to complete expected collectives
442445
rank_counts = defaultdict(list)
443446
for c in collectives:
447+
if c.state != 'scheduled':
448+
continue
444449
rank_counts['appeared'].append(c.file_id)
445450
if get_correct_seq_id(c) <= max_completed_collective_seq_id:
446451
rank_counts['mismatched'].append(c.file_id)
@@ -526,6 +531,7 @@ def pair_send_recv_operations():
526531
missing_ranks = set(global_ranks) - set(unique_ranks)
527532
missing_ranks = missing_ranks | set(map(int, mismatched_rank_counts.keys()))
528533

534+
correct_unique_ranks = set(unique_ranks) - missing_ranks
529535
logger.debug(f"missing_ranks: {missing_ranks}")
530536
process_group_str = process_group
531537

@@ -544,7 +550,7 @@ def pair_send_recv_operations():
544550
(size_str, 15, ''),
545551
(dtype, 10, ''),
546552
(total_unique_ranks, 10, 'd'),
547-
(','.join(map(str, unique_ranks)), 40, ''),
553+
(','.join(map(str, correct_unique_ranks)), 40, ''),
548554
(','.join(map(str, sorted(missing_ranks))), 40, ''),
549555
]
550556

@@ -827,7 +833,7 @@ def extract_collectives(data: Dict, file_id: str) -> List[Collective]:
827833
"""
828834
collectives = []
829835
for entry in data['entries']:
830-
if 'collective_seq_id' in entry:
836+
if 'collective_seq_id' in entry and entry['state'] == 'scheduled':
831837
collective = Collective(
832838
file_id=file_id,
833839
collective_seq_id=entry['collective_seq_id'],
@@ -843,6 +849,7 @@ def extract_collectives(data: Dict, file_id: str) -> List[Collective]:
843849
'time_discovered_completed_ns', entry['time_created_ns']
844850
),
845851
process_group=entry['process_group'],
852+
state=entry['state'],
846853
input_sizes=entry['input_sizes'],
847854
output_sizes=entry['output_sizes'],
848855
input_dtypes=entry['input_dtypes'],

0 commit comments

Comments
 (0)