@@ -233,12 +233,7 @@ def print_ranks_in_pgs(head_nodes, pg_dict, missing_or_completed="Missing"):
233233 if head_nodes_completed :
234234 print_ranks_in_pgs (head_nodes_completed , completed_pg , "Completed" )
235235 analysis_output = output .getvalue ()
236- attribution_kwargs = {
237- "model" : self .args .model ,
238- "scheduling_order" : self .args .scheduling_order_file ,
239- "verbose" : self .args .verbose ,
240- }
241- return analysis_output , attribution_kwargs
236+ return analysis_output
242237
243238 async def collective_analysis (self , analysis_output : str , ** kwargs ):
244239 """
@@ -258,7 +253,7 @@ async def collective_analysis(self, analysis_output: str, **kwargs):
258253 Note:
259254 Requires the NVIDIA_API_KEY environment variable to be set
260255 """
261- result = analysis_output [ 0 ]
256+ result = analysis_output
262257 if self .args .llm_analyze :
263258 model = kwargs ["model" ]
264259 verbose = kwargs ["verbose" ]
@@ -331,14 +326,14 @@ def analyze_matches(self, verbose: bool = False):
331326 Args:
332327 verbose (bool): Whether to include more detailed analysis in the output
333328 """
334- print ("\n === Collective Operations Analysis ===\n " )
329+ logger . info ("\n === Collective Operations Analysis ===\n " )
335330
336331 if verbose :
337- print ("Files processed:" )
332+ logger . info ("Files processed:" )
338333 for rank_id in sorted (self .collectives_by_file .keys ()):
339334 count = len (self .collectives_by_file [rank_id ])
340- print (f" { rank_id } : { count } collectives" )
341- print ( )
335+ logger . info (f" { rank_id } : { count } collectives" )
336+ logger . info ( "" )
342337
343338 # Extract unique sub-group types from the data
344339 group_types = set ()
@@ -356,9 +351,9 @@ def analyze_matches(self, verbose: bool = False):
356351 # If no group types were found, use default ones
357352 if not group_types :
358353 group_types = ["TENSOR_MODEL" , "PIPELINE_MODEL" , "DATA_PARALLEL" ]
359- print ("No sub-group types found in data. Using default group types." )
354+ logger . info ("No sub-group types found in data. Using default group types." )
360355 else :
361- print (f"Found group types: { ', ' .join (group_types )} " )
356+ logger . info (f"Found group types: { ', ' .join (group_types )} " )
362357
363358 # Categorize collective groups by type
364359 categorized_groups = {group_type : [] for group_type in group_types }
@@ -381,7 +376,7 @@ def analyze_matches(self, verbose: bool = False):
381376 missing_pg = defaultdict (list )
382377 for group_type in group_types :
383378 if categorized_groups [group_type ]:
384- print (f"\n === { group_type } Collectives ===\n " )
379+ logger . info (f"=== { group_type } Collectives ===" )
385380
386381 # Headers for this section
387382 headers = [
@@ -396,8 +391,8 @@ def analyze_matches(self, verbose: bool = False):
396391 ]
397392
398393 header_line = " " .join (f"{ name :>{width }} " for name , width in headers )
399- print (header_line )
400- print ("-" * len (header_line ))
394+ logger . info (header_line )
395+ logger . info ("-" * len (header_line ))
401396
402397 def get_correct_seq_id (collective ):
403398 if (
@@ -562,18 +557,18 @@ def pair_send_recv_operations():
562557 continue
563558 else :
564559 missing_pg [(int )(parsed_row [0 ])].append (parsed_row )
565- print (row )
560+ logger . info (row )
566561
567562 # Print detailed rank count distribution
568563 if verbose :
569- print (f" Rank count distribution for { process_group_str } :" )
564+ logger . info (f" Rank count distribution for { process_group_str } :" )
570565 for rank , count in sorted (appeared_rank_counts .items ()):
571- print (f" Rank { rank } : { count } occurrences" )
566+ logger . info (f" Rank { rank } : { count } occurrences" )
572567
573568 # Print operation type distribution with paired send/recv analysis
574- print (" Operation type distribution:" )
569+ logger . info (" Operation type distribution:" )
575570 # Print paired send/recv operations
576- print (" Send/Receive pairs (src->dst):" )
571+ logger . info (" Send/Receive pairs (src->dst):" )
577572
578573 # Print each pair with send and recv counts
579574 for src , dst in all_pairs :
@@ -586,17 +581,16 @@ def pair_send_recv_operations():
586581 else :
587582 imbalance = ""
588583
589- print (
584+ logger . info (
590585 f" { global_ranks [int (src )]} ->{ global_ranks [int (dst )]} : { send_count } sends, { recv_count } recvs{ imbalance } "
591586 )
592587
593588 # Print other operations
594589 if other_ops :
595- print (" Other operations:" )
590+ logger . info (" Other operations:" )
596591 for op , count in sorted (other_ops .items (), key = lambda x : (- x [1 ], x [0 ])):
597- print (f" { op } : { count } " )
592+ logger . info (f" { op } : { count } " )
598593
599- print () # Add an empty line for better readability
600594 return completed_pg , missing_pg
601595
602596 def group_pgs (self , pgs : Dict [str , List [str ]]) -> Dict [int , List [int ]]:
0 commit comments