```diff
@@ -11,7 +11,7 @@
 from pathlib import Path
 from typing import Dict, List, Tuple, Union
 
-from nvidia_resiliency_ext.attribution.base import NVRxAttribution
+from nvidia_resiliency_ext.attribution.base import AttributionState, NVRxAttribution
 from nvidia_resiliency_ext.attribution.utils import capture_stdout
 
 logging.basicConfig(level=logging.INFO)
@@ -119,21 +119,25 @@ def __init__(self, args: argparse.Namespace):
 
     # output handler to print the attribution results
     async def print_output(self, attribution_result: str):
-        # print(attribution_result)
-        for line in attribution_result.split('\n'):
-            logger.info(line)
         hanging_ranks_list = []
-        # If LLM is used, we assume the following format of the output
         if self.llm and self.args.llm_analyze:
+            logger.info(attribution_result)
             hanging_ranks = re.search(r'.*hanging ranks: \{([^}]*)\}', attribution_result)
             if hanging_ranks is not None:
                 # Parse the hanging ranks from the analysis output
                 hanging_ranks_str = hanging_ranks.group(1).strip()
                 hanging_ranks_list = list(map(int, hanging_ranks_str.split(',')))
-        return hanging_ranks_list
+        else:
+            for idx, line in enumerate(attribution_result.split('\n')):
+                line_list = line.split('|')
+                if len(line_list) >= 5:
+                    logger.info(line)
+                    if idx >= 1:
+                        hanging_ranks_list.append(line_list[5])
+        return f"hanging ranks: {hanging_ranks_list}", AttributionState.CONTINUE
 
     # preprocess input to analyze the collective operations
-    async def preprocess_FR_dumps(self, input_data: List[str]):
+    async def preprocess_FR_dumps(self, input_data: List[str]) -> str:
         """
         Analyzes the collective operations across multiple JSON files.
 
```
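For context, the reworked `print_output` has two parsing paths: when LLM analysis is enabled (`self.llm and self.args.llm_analyze`) it pulls the rank set out of a `hanging ranks: {...}` marker in the LLM's text; otherwise it walks a pipe-delimited summary table and collects field index 5 of every row after the header. The sketch below replays both paths on made-up inputs; the table layout, the sample LLM sentence, and the stand-in `AttributionState` enum are illustrative assumptions rather than the exact strings the FR analysis emits, and the index-5 guard is tightened so the index is always in range.

```python
import re
from enum import Enum, auto


class AttributionState(Enum):
    """Stand-in for the enum imported above; illustrative only."""
    CONTINUE = auto()


def parse_llm_output(text: str) -> list:
    """Mirror of the LLM branch: extract the rank set from a 'hanging ranks: {...}' marker."""
    match = re.search(r'.*hanging ranks: \{([^}]*)\}', text)
    if match is None:
        return []
    return [int(rank) for rank in match.group(1).split(',')]


def parse_table_output(text: str) -> list:
    """Mirror of the non-LLM branch: skip the header row, collect field index 5 of each data row.

    The diff checks len(line_list) >= 5 before indexing [5]; here the guard is > 5
    and fields are stripped, purely for a tidy standalone example.
    """
    ranks = []
    for idx, line in enumerate(text.split('\n')):
        fields = line.split('|')
        if len(fields) > 5 and idx >= 1:
            ranks.append(fields[5].strip())
    return ranks


# Hypothetical inputs; the real strings come from the FR dump analysis and the LLM.
llm_text = "Based on the traces, hanging ranks: {3, 7}"
table_text = (
    "| pg | collective | seq | state | rank |\n"
    "| 0  | allreduce  | 42  | missing | 3 |\n"
    "| 0  | allreduce  | 42  | missing | 7 |"
)

print(parse_llm_output(llm_text))                            # [3, 7]
print(f"hanging ranks: {parse_table_output(table_text)}")    # hanging ranks: ['3', '7']
print(AttributionState.CONTINUE)                             # AttributionState.CONTINUE
```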
```diff
@@ -235,7 +239,7 @@ def print_ranks_in_pgs(head_nodes, pg_dict, missing_or_completed="Missing"):
         analysis_output = output.getvalue()
         return analysis_output
 
-    async def collective_analysis(self, analysis_output: str, **kwargs):
+    async def collective_analysis(self, analysis_output: str, **kwargs) -> str:
         """
         Analyze the collective operations using a Large Language Model (LLM).
 
```
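The new `-> str` annotations on `preprocess_FR_dumps` and `collective_analysis` make the hand-off explicit: each stage consumes and produces plain text, and `print_output` turns the final text into a `(message, AttributionState)` pair. Below is a minimal sketch of that shape using plain `asyncio`; the stage names and the tuple convention come from the diff, while the function bodies, the stand-in enum, and the manual chaining are assumptions standing in for the actual NVRxAttribution plumbing.

```python
import asyncio
from enum import Enum, auto


class AttributionState(Enum):
    """Stand-in control state; the real enum lives in nvidia_resiliency_ext.attribution.base."""
    CONTINUE = auto()


async def preprocess_FR_dumps(input_data: list) -> str:
    # Placeholder: the real coroutine parses flight-recorder JSON dumps and
    # returns the captured text summary of collective operations.
    return f"summary of {len(input_data)} FR dumps"


async def collective_analysis(analysis_output: str) -> str:
    # Placeholder: the real coroutine sends the summary to an LLM and returns
    # its textual analysis (e.g. containing 'hanging ranks: {...}').
    return f"LLM analysis of: {analysis_output}"


async def print_output(attribution_result: str):
    # Output handler: report the result and signal the pipeline to continue.
    return f"handled: {attribution_result}", AttributionState.CONTINUE


async def main():
    summary = await preprocess_FR_dumps(["rank0.json", "rank1.json"])
    analysis = await collective_analysis(summary)
    message, state = await print_output(analysis)
    print(message, state)


asyncio.run(main())
```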