1- import argparse
21import os
3- import sys
42from pathlib import Path
5-
3+
64from primus .core .launcher .parser import PrimusParser
75from primus .core .projection .training_config import convert_primus_config_to_projection_config
86from primus .core .projection .module_profilers .language_model import build_profiler , get_language_model_profiler_spec
119def print_profiler_hierarchy (profiler , batch_size , seq_len , rank = None , name = "root" , depth = 0 , visited = None ):
1210 """
1311 Recursively print the profiler hierarchy with num_params and activation_memory for each component.
14-
12+
1513 Args:
1614 profiler: The profiler instance to print
1715 batch_size: Batch size for activation memory calculation
@@ -23,15 +21,15 @@ def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="roo
2321 """
2422 if visited is None :
2523 visited = set ()
26-
24+
2725 # Avoid infinite recursion if profilers reference each other
2826 profiler_id = id (profiler )
2927 if profiler_id in visited :
3028 return
3129 visited .add (profiler_id )
32-
30+
3331 indent = " " * depth
34-
32+
3533 # Calculate metrics for this profiler
3634 try :
3735 if depth == 0 :
@@ -44,7 +42,7 @@ def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="roo
4442 print (f"{ indent } [{ name } ]" )
4543 print (f"{ indent } Params: { num_params / 1e9 :.6f} Billion ({ num_params :,} )" )
4644 print (f"{ indent } Activation Memory: { activation_mem / 1024 / 1024 / 1024 :.4f} GB" )
47-
45+
4846 # Recursively process sub_profilers if they exist
4947 if hasattr (profiler , 'sub_profilers' ) and profiler .sub_profilers :
5048 for sub_name , sub_profiler in profiler .sub_profilers .items ():
@@ -75,16 +73,16 @@ def launch_projection_from_cli(args, overrides):
7573 seq_len = training_config .runtime_config .sequence_length
7674 batch_size = training_config .runtime_config .micro_batch_size
7775 rank = int (os .getenv ('RANK' , '0' ))
78-
76+
7977 # Print recursive profiler hierarchy with detailed breakdown
8078 print ("\n " + "=" * 100 )
8179 print (f"[Primus:Projection] Component-wise Profiling Results (Rank { rank } ):" )
8280 print ("=" * 100 )
8381 print ()
84-
82+
8583 # Print the complete hierarchy recursively
8684 print_profiler_hierarchy (model_profiler , batch_size , seq_len , rank = rank , name = "LanguageModelProfiler" , depth = 0 )
87-
85+
8886 # Get overall totals from the model profiler for this rank
8987 num_params = model_profiler .estimated_num_params (rank = rank )
9088 activation_memory = model_profiler .estimated_activation_memory (batch_size , seq_len )
@@ -98,4 +96,4 @@ def launch_projection_from_cli(args, overrides):
9896 f"{ activation_memory / 1024 / 1024 / 1024 :.4f} GB" )
9997 print (f" Projected Total Memory: "
10098 f"{ (num_params * num_bytes_per_param + activation_memory ) / 1024 / 1024 / 1024 :.4f} GB" )
101- print ("=" * 100 )
99+ print ("=" * 100 )
0 commit comments