Skip to content

Commit 6fd0e8e

Browse files
committed
Update fr_attribution.py to be used without llm
1 parent 283259f commit 6fd0e8e

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pathlib import Path
1212
from typing import Dict, List, Tuple, Union
1313

14-
from nvidia_resiliency_ext.attribution.base import NVRxAttribution
14+
from nvidia_resiliency_ext.attribution.base import AttributionState, NVRxAttribution
1515
from nvidia_resiliency_ext.attribution.utils import capture_stdout
1616

1717
logging.basicConfig(level=logging.INFO)
@@ -119,21 +119,25 @@ def __init__(self, args: argparse.Namespace):
119119

120120
# output handler to print the attribution results
121121
async def print_output(self, attribution_result: str):
122-
# print(attribution_result)
123-
for line in attribution_result.split('\n'):
124-
logger.info(line)
125122
hanging_ranks_list = []
126-
# If LLM is used, we assume the following format of the output
127123
if self.llm and self.args.llm_analyze:
124+
logger.info(attribution_result)
128125
hanging_ranks = re.search(r'.*hanging ranks: \{([^}]*)\}', attribution_result)
129126
if hanging_ranks is not None:
130127
# Parse the hanging ranks from the analysis output
131128
hanging_ranks_str = hanging_ranks.group(1).strip()
132129
hanging_ranks_list = list(map(int, hanging_ranks_str.split(',')))
133-
return hanging_ranks_list
130+
else:
131+
for idx, line in enumerate(attribution_result.split('\n')):
132+
line_list = line.split('|')
133+
if len(line_list) >= 5:
134+
logger.info(line)
135+
if idx >= 1:
136+
hanging_ranks_list.append(line_list[5])
137+
return f"hanging ranks: {hanging_ranks_list}", AttributionState.CONTINUE
134138

135139
# preprocess input to analyze the collective operations
136-
async def preprocess_FR_dumps(self, input_data: List[str]):
140+
async def preprocess_FR_dumps(self, input_data: List[str]) -> str:
137141
"""
138142
Analyzes the collective operations across multiple JSON files.
139143
@@ -235,7 +239,7 @@ def print_ranks_in_pgs(head_nodes, pg_dict, missing_or_completed="Missing"):
235239
analysis_output = output.getvalue()
236240
return analysis_output
237241

238-
async def collective_analysis(self, analysis_output: str, **kwargs):
242+
async def collective_analysis(self, analysis_output: str, **kwargs) -> str:
239243
"""
240244
Analyze the collective operations using a Large Language Model (LLM).
241245

0 commit comments

Comments
 (0)