Skip to content

Commit 6062163

Browse files
committed
Add minor fix to handle local-to-global pg id mapping
1 parent 1a24d0b commit 6062163

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,14 @@ def get_correct_seq_id(collective):
415415
group_by_seq_id = defaultdict(list)
416416
max_completed_collective_seq_id = -1
417417
max_enqueued_collective_seq_id = -1
418-
418+
local_pg_map = dict()
419419
for c in collectives:
420420
rank_id = c.file_id
421-
pg_status = self.pg_status[rank_id][process_group]
421+
pg_status = self.pg_status[rank_id][str(c.pg_id)]
422+
logger.debug(
423+
f"rank_id: {rank_id}, c.pg_id: {c.pg_id}, c.file_id: {c.file_id}, c.collective_seq_id: {c.collective_seq_id}, process_group: {process_group}"
424+
)
425+
local_pg_map[rank_id] = c.pg_id
422426
if (
423427
pg_status['last_completed_collective']
424428
>= max_completed_collective_seq_id
@@ -433,20 +437,25 @@ def get_correct_seq_id(collective):
433437
logger.debug(
434438
f"max_enqueued_collective_seq_id: {max_enqueued_collective_seq_id}"
435439
)
436-
440+
local_pg_id = local_pg_map[rank_id]
437441
# Ranks holding entries earlier than max_completed_collective_seq_id -> ranks failing to complete expected collectives
438442
rank_counts = defaultdict(list)
439443
for c in collectives:
440444
rank_counts['appeared'].append(c.file_id)
441445
if get_correct_seq_id(c) <= max_completed_collective_seq_id:
442446
rank_counts['mismatched'].append(c.file_id)
443447
appeared_rank_counts = Counter(rank_counts['appeared'])
444-
445448
# Ranks with fewer enqueued collectives than max_enqueued_collective_seq_id -> host not making expected progress
446449
for rank_id in self.pg_configs[process_group]['ranks']:
447450
rank_id = str(rank_id)
448451
if (
449-
self.pg_status[str(rank_id)][process_group]['last_enqueued_collective']
452+
rank_id not in self.pg_status
453+
or str(local_pg_id) not in self.pg_status[rank_id]
454+
):
455+
continue
456+
457+
if (
458+
self.pg_status[rank_id][str(local_pg_id)]['last_enqueued_collective']
450459
< max_enqueued_collective_seq_id
451460
):
452461
rank_counts['mismatched'].append(rank_id)

0 commit comments

Comments
 (0)