@@ -59,13 +59,12 @@ def parse_log_file(log_file):
5959 # Standard performance log: "... took X.Xs"
6060 log_pattern = re .compile (r"\[MLF.* Step=(-?\d+) Rank=(-?[\d/]+) (.*?):[\d]+\] (.*?) took ([\d.]+)s" )
6161
62- # Format: [MLF YYYY-MM-DD HH:MM:SS,mmm ...]
63- timestamp_prefix = r"\[MLF (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3})"
64-
6562 # Throughput logs
6663 # Read/Write patterns to capture timestamp and bytes
67- # Capture groups: 1: Timestamp, 2: Step, 3: Rank, 4: Bytes, 5: Duration
64+ # Capture groups: 1: Timestamp, 2: Step, 3: Rank, 4: Bytes, 5: Duration, 6: GB/s, 7: Num files/buckets
6865 # [Timestamp + ...] + "Read 123 bytes in 0.123 s (x.xx GB/s) from 1 buckets"
66+ # Format: [MLF YYYY-MM-DD HH:MM:SS,mmm ...]
67+ timestamp_prefix = r"\[MLF (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3})"
6968 read_throughput_pattern = re .compile (
7069 timestamp_prefix + r".*? Step=(-?\d+) Rank=(-?[\d/]+) .*?\] Read (\d+) "
7170 r"bytes in ([\d.]+) s \(([\d.]+) GB/s\) from (\d+) files"
@@ -80,6 +79,14 @@ def parse_log_file(log_file):
8079 # Format: { 'Read': { step: [ (end_time, duration, bytes, rank), ... ] }, 'Write': { ... } }
8180 raw_throughput_records = {"Read" : defaultdict (list ), "Write" : defaultdict (list )}
8281
82+ # Pattern to capture the start of a training step
83+ # Groups: 1: Timestamp, 2: Step, 3: Rank
84+ batch_start_pattern = re .compile (
85+ timestamp_prefix + r".*? Step=(-?\d+) Rank=(-?[\d/]+) .*? event=on_train_batch_start"
86+ )
87+ # Create a dictionary to track the earliest start time for each step
88+ step_start_times = {}
89+
8390 train_step_pattern = re .compile (r"global_step: (\d+).*?train_step_timing in s: ([\d.]+)" )
8491 data = defaultdict (lambda : defaultdict (list ))
8592 ordered_functions = []
@@ -108,7 +115,6 @@ def parse_log_file(log_file):
108115 if write_match :
109116 _process_throughput_match (write_match , "Write" , raw_throughput_records )
110117 continue
111- continue
112118
113119 # Check for Train Step
114120 train_match = train_step_pattern .search (line )
@@ -119,7 +125,20 @@ def parse_log_file(log_file):
119125 ordered_functions .append (metric_name )
120126 data [metric_name ][int (step )].append ((float (time_taken ), "NA" ))
121127
122- return data , ordered_functions , raw_throughput_records
128+ start_match = batch_start_pattern .search (line )
129+ if start_match :
130+ ts_str , step_idx , _ = start_match .groups ()
131+ step_idx = int (step_idx )
132+ ts_dt = datetime .strptime (ts_str .replace ("," , "." ), "%Y-%m-%d %H:%M:%S.%f" )
133+
134+ # Record the earliest timestamp (of on_train_batch_start) seen for this step (first rank that logs it)
135+ # The first step in MLF is logged as -1 (-1->1->2) while in Nemo it's 0 (0->1->2)
136+ if step_idx == - 1 :
137+ step_idx = 0
138+ step_start_times [step_idx ] = min (ts_dt , step_start_times .get (step_idx , ts_dt ))
139+ continue
140+
141+ return data , ordered_functions , raw_throughput_records , step_start_times
123142
124143
125144def _process_throughput_match (match , mode , raw_records ):
@@ -170,58 +189,43 @@ def calculate_per_node_throughput(records, ranks_per_node):
170189 return per_node_stats
171190
172191
173- def analyze_step_time_breakdown (log_file_path ):
192+ def analyze_step_time_breakdown (step_start_times , data ):
174193 """
175- Analyzes the wall-clock time gap between consecutive global steps to identify overheads .
194+ Analyzes the wall-clock time gap between consecutive global steps.
176195
177- This function calculates:
178- 1. Total Gap: Wall-clock time elapsed between the end of step N-1 and step N.
179- 2. Other Time (Overhead): Total Gap minus the reported training time (train_step_timing from NeMo logs).
196+ Uses the earliest 'on_train_batch_start' event as the definitive
197+ timestamp for the beginning of each step.
180198 """
181- timestamp_pattern = re .compile (r"\[NeMo \w (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" )
182- train_step_pattern = re .compile (r"global_step: (\d+).*?train_step_timing in s: ([\d.]+)" )
183-
184- step_data = []
185- last_seen_timestamp = None
186-
187- try :
188- with open (log_file_path , "r" ) as f :
189- for line in f :
190- ts_match = timestamp_pattern .search (line )
191- if ts_match :
192- try :
193- last_seen_timestamp = datetime .strptime (ts_match .group (1 ), "%Y-%m-%d %H:%M:%S" )
194- except ValueError :
195- pass
196-
197- step_match = train_step_pattern .search (line )
198- if step_match and last_seen_timestamp :
199- step_data .append (
200- {
201- "step" : int (step_match .group (1 )),
202- "finish_time" : last_seen_timestamp ,
203- "train_time" : float (step_match .group (2 )),
204- }
205- )
206- except Exception as e :
207- print (f"Error analyzing breakdown: { e } " )
208- return []
209-
210199 results = []
211- for i in range (1 , len (step_data )):
212- prev , curr = step_data [i - 1 ], step_data [i ]
213- if curr ["step" ] > prev ["step" ]:
214- time_delta = max ((curr ["finish_time" ] - prev ["finish_time" ]).total_seconds (), 0.0 )
215- other_time = time_delta - curr ["train_time" ]
216- results .append (
217- {
218- "step" : curr ["step" ],
219- "timestamp" : curr ["finish_time" ],
220- "total_gap" : time_delta ,
221- "train_time" : curr ["train_time" ],
222- "other_time" : other_time ,
223- }
224- )
200+ # We need the reported train_step_timing to calculate overhead
201+ train_timings = data .get ("train_step_timing" , {})
202+
203+ sorted_steps = sorted (step_start_times .keys ())
204+
205+ for i in range (len (sorted_steps ) - 1 ):
206+ curr_step = sorted_steps [i ]
207+ next_step = sorted_steps [i + 1 ]
208+
209+ # Calculate Total Gap between starts of two consecutive steps
210+ start_curr = step_start_times [curr_step ]
211+ start_next = step_start_times [next_step ]
212+ total_gap = (start_next - start_curr ).total_seconds ()
213+
214+ try :
215+ actual_train_time = train_timings [curr_step ][0 ][0 ]
216+ except (KeyError , IndexError , TypeError ):
217+ actual_train_time = 0.0
218+ other_time = total_gap - actual_train_time
219+
220+ results .append (
221+ {
222+ "step" : curr_step ,
223+ "timestamp" : start_curr ,
224+ "total_gap" : total_gap ,
225+ "train_time" : actual_train_time ,
226+ "other_time" : other_time ,
227+ }
228+ )
225229 return results
226230
227231
@@ -398,7 +402,7 @@ def main():
398402 print ("*********" * 8 )
399403 print ()
400404
401- data , ordered_functions , raw_throughput_records = parse_log_file (args .log_file )
405+ data , ordered_functions , raw_throughput_records , step_start_times = parse_log_file (args .log_file )
402406 stats = calculate_statistics (data )
403407 overall_stats = calculate_overall_statistics (data )
404408
@@ -484,7 +488,7 @@ def main():
484488 )
485489 print ("-" * 85 )
486490
487- breakdown = analyze_step_time_breakdown (args . log_file )
491+ breakdown = analyze_step_time_breakdown (step_start_times , data )
488492 other_times = []
489493 if not breakdown :
490494 print ("No consecutive steps or timestamps found to calculate gaps." )
0 commit comments