Skip to content

Commit 283259f

Browse files
committed
Change any remaining print calls to logger calls
1 parent 6bdcd77 commit 283259f

File tree

1 file changed

+19
-25
lines changed

1 file changed

+19
-25
lines changed

src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,7 @@ def print_ranks_in_pgs(head_nodes, pg_dict, missing_or_completed="Missing"):
233233
if head_nodes_completed:
234234
print_ranks_in_pgs(head_nodes_completed, completed_pg, "Completed")
235235
analysis_output = output.getvalue()
236-
attribution_kwargs = {
237-
"model": self.args.model,
238-
"scheduling_order": self.args.scheduling_order_file,
239-
"verbose": self.args.verbose,
240-
}
241-
return analysis_output, attribution_kwargs
236+
return analysis_output
242237

243238
async def collective_analysis(self, analysis_output: str, **kwargs):
244239
"""
@@ -258,7 +253,7 @@ async def collective_analysis(self, analysis_output: str, **kwargs):
258253
Note:
259254
Requires the NVIDIA_API_KEY environment variable to be set
260255
"""
261-
result = analysis_output[0]
256+
result = analysis_output
262257
if self.args.llm_analyze:
263258
model = kwargs["model"]
264259
verbose = kwargs["verbose"]
@@ -331,14 +326,14 @@ def analyze_matches(self, verbose: bool = False):
331326
Args:
332327
verbose (bool): Whether to include more detailed analysis in the output
333328
"""
334-
print("\n=== Collective Operations Analysis ===\n")
329+
logger.info("\n=== Collective Operations Analysis ===\n")
335330

336331
if verbose:
337-
print("Files processed:")
332+
logger.info("Files processed:")
338333
for rank_id in sorted(self.collectives_by_file.keys()):
339334
count = len(self.collectives_by_file[rank_id])
340-
print(f" {rank_id}: {count} collectives")
341-
print()
335+
logger.info(f" {rank_id}: {count} collectives")
336+
logger.info("")
342337

343338
# Extract unique sub-group types from the data
344339
group_types = set()
@@ -356,9 +351,9 @@ def analyze_matches(self, verbose: bool = False):
356351
# If no group types were found, use default ones
357352
if not group_types:
358353
group_types = ["TENSOR_MODEL", "PIPELINE_MODEL", "DATA_PARALLEL"]
359-
print("No sub-group types found in data. Using default group types.")
354+
logger.info("No sub-group types found in data. Using default group types.")
360355
else:
361-
print(f"Found group types: {', '.join(group_types)}")
356+
logger.info(f"Found group types: {', '.join(group_types)}")
362357

363358
# Categorize collective groups by type
364359
categorized_groups = {group_type: [] for group_type in group_types}
@@ -381,7 +376,7 @@ def analyze_matches(self, verbose: bool = False):
381376
missing_pg = defaultdict(list)
382377
for group_type in group_types:
383378
if categorized_groups[group_type]:
384-
print(f"\n=== {group_type} Collectives ===\n")
379+
logger.info(f"=== {group_type} Collectives ===")
385380

386381
# Headers for this section
387382
headers = [
@@ -396,8 +391,8 @@ def analyze_matches(self, verbose: bool = False):
396391
]
397392

398393
header_line = " ".join(f"{name:>{width}}" for name, width in headers)
399-
print(header_line)
400-
print("-" * len(header_line))
394+
logger.info(header_line)
395+
logger.info("-" * len(header_line))
401396

402397
def get_correct_seq_id(collective):
403398
if (
@@ -562,18 +557,18 @@ def pair_send_recv_operations():
562557
continue
563558
else:
564559
missing_pg[(int)(parsed_row[0])].append(parsed_row)
565-
print(row)
560+
logger.info(row)
566561

567562
# Print detailed rank count distribution
568563
if verbose:
569-
print(f" Rank count distribution for {process_group_str}:")
564+
logger.info(f" Rank count distribution for {process_group_str}:")
570565
for rank, count in sorted(appeared_rank_counts.items()):
571-
print(f" Rank {rank}: {count} occurrences")
566+
logger.info(f" Rank {rank}: {count} occurrences")
572567

573568
# Print operation type distribution with paired send/recv analysis
574-
print(" Operation type distribution:")
569+
logger.info(" Operation type distribution:")
575570
# Print paired send/recv operations
576-
print(" Send/Receive pairs (src->dst):")
571+
logger.info(" Send/Receive pairs (src->dst):")
577572

578573
# Print each pair with send and recv counts
579574
for src, dst in all_pairs:
@@ -586,17 +581,16 @@ def pair_send_recv_operations():
586581
else:
587582
imbalance = ""
588583

589-
print(
584+
logger.info(
590585
f" {global_ranks[int(src)]}->{global_ranks[int(dst)]}: {send_count} sends, {recv_count} recvs{imbalance}"
591586
)
592587

593588
# Print other operations
594589
if other_ops:
595-
print(" Other operations:")
590+
logger.info(" Other operations:")
596591
for op, count in sorted(other_ops.items(), key=lambda x: (-x[1], x[0])):
597-
print(f" {op}: {count}")
592+
logger.info(f" {op}: {count}")
598593

599-
print() # Add an empty line for better readability
600594
return completed_pg, missing_pg
601595

602596
def group_pgs(self, pgs: Dict[str, List[str]]) -> Dict[int, List[int]]:

0 commit comments

Comments (0)