3636import re
3737import time
3838from collections import defaultdict
39- from datetime import datetime
39+ from datetime import datetime , timedelta
4040
4141import numpy as np
4242
@@ -59,18 +59,27 @@ def parse_log_file(log_file):
5959 # Standard performance log: "... took X.Xs"
6060 log_pattern = re .compile (r"\[MLF.* Step=(-?\d+) Rank=(-?[\d/]+) (.*?):[\d]+\] (.*?) took ([\d.]+)s" )
6161
62+ # Format: [MLF YYYY-MM-DD HH:MM:SS,mmm ...]
63+ timestamp_prefix = r"\[MLF (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3})"
64+
6265 # Throughput logs
63- # Read: "Read 123 bytes in 0.1234 s (X.XX GB/s) from 1 files"
66+ # Read/Write patterns to capture timestamp and bytes
67+ # Capture groups: 1: Timestamp, 2: Step, 3: Rank, 4: Bytes, 5: Duration
68+ # [Timestamp + ...] + "Read 123 bytes in 0.123 s (x.xx GB/s) from 1 buckets"
6469 read_throughput_pattern = re .compile (
65- r"\[MLF.* Step=(-?\d+) Rank=(-?[\d/]+) .*?\] Read (\d+) "
70+ timestamp_prefix + r".*? Step=(-?\d+) Rank=(-?[\d/]+) .*?\] Read (\d+) "
6671 r"bytes in ([\d.]+) s \(([\d.]+) GB/s\) from (\d+) files"
6772 )
68- # Write: "Written 123 bytes in 0.1234 s (X.XX GB/s) from 1 buckets"
73+ # [Timestamp + ...] + "Written 123 bytes in 0.123 s (X.XX GB/s) from 1 buckets"
6974 write_throughput_pattern = re .compile (
70- r"\[MLF.* Step=(-?\d+) Rank=(-?[\d/]+) .*?\] Written (\d+) "
75+ timestamp_prefix + r".*? Step=(-?\d+) Rank=(-?[\d/]+) .*?\] Written (\d+) "
7176 r"bytes in ([\d.]+) s \(([\d.]+) GB/s\) from (\d+) buckets"
7277 )
7378
79+ # Dictionary to store raw records for total throughput calculation
80+ # Format: { 'Read': { step: [ (end_time, duration, bytes, rank), ... ] }, 'Write': { ... } }
81+ raw_throughput_records = {"Read" : defaultdict (list ), "Write" : defaultdict (list )}
82+
7483 train_step_pattern = re .compile (r"global_step: (\d+).*?train_step_timing in s: ([\d.]+)" )
7584 data = defaultdict (lambda : defaultdict (list ))
7685 ordered_functions = []
@@ -91,35 +100,14 @@ def parse_log_file(log_file):
91100 # Check for Read Throughput
92101 read_match = read_throughput_pattern .search (line )
93102 if read_match :
94- step , rank , bytes_read , duration , mb_per_s , num_files = read_match .groups ()
95- # Track Throughput
96- metric_name = "Read Throughput (GB/s)"
97- if metric_name not in ordered_functions :
98- ordered_functions .append (metric_name )
99- data [metric_name ][int (step )].append ((float (mb_per_s ), rank ))
100-
101- # Track Duration (optional, but good for direct comparison)
102- metric_name_duration = "Read Duration (s)"
103- if metric_name_duration not in ordered_functions :
104- ordered_functions .append (metric_name_duration )
105- data [metric_name_duration ][int (step )].append ((float (duration ), rank ))
103+ _process_throughput_match (read_match , "Read" , raw_throughput_records )
106104 continue
107105
108106 # Check for Write Throughput
109107 write_match = write_throughput_pattern .search (line )
110108 if write_match :
111- step , rank , bytes_written , duration , mb_per_s , num_buckets = write_match .groups ()
112- # Track Throughput
113- metric_name = "Write Throughput (GB/s)"
114- if metric_name not in ordered_functions :
115- ordered_functions .append (metric_name )
116- data [metric_name ][int (step )].append ((float (mb_per_s ), rank ))
117-
118- # Track Duration
119- metric_name_duration = "Write Duration (s)"
120- if metric_name_duration not in ordered_functions :
121- ordered_functions .append (metric_name_duration )
122- data [metric_name_duration ][int (step )].append ((float (duration ), rank ))
109+ _process_throughput_match (write_match , "Write" , raw_throughput_records )
110+ continue
123111 continue
124112
125113 # Check for Train Step
@@ -131,7 +119,55 @@ def parse_log_file(log_file):
131119 ordered_functions .append (metric_name )
132120 data [metric_name ][int (step )].append ((float (time_taken ), "NA" ))
133121
134- return data , ordered_functions
122+ return data , ordered_functions , raw_throughput_records
123+
124+
125+ def _process_throughput_match (match , mode , raw_records ):
126+ """Extracting throughput data from a regex match."""
127+ timestamp , step , rank , bytes_val , duration , _ , _ = match .groups ()
128+ end_time = datetime .strptime (timestamp .replace ("," , "." ), "%Y-%m-%d %H:%M:%S.%f" )
129+ raw_records [mode ][int (step )].append (
130+ {"end_time" : end_time , "duration" : float (duration ), "bytes" : int (bytes_val ), "rank" : rank }
131+ )
132+
133+
134+ def _calculate_throughput_stats (entries ):
135+ """Calculating throughput from a list of entries."""
136+ if not entries :
137+ return None
138+ total_bytes = sum (e ["bytes" ] for e in entries )
139+ starts = [e ["end_time" ] - timedelta (seconds = e ["duration" ]) for e in entries ]
140+ ends = [e ["end_time" ] for e in entries ]
141+ duration = (max (ends ) - min (starts )).total_seconds ()
142+ if duration > 0 :
143+ return {"throughput" : (total_bytes / 1e9 ) / duration , "duration" : duration , "total_gb" : total_bytes / 1e9 }
144+ return None
145+
146+
def calculate_total_throughput(records):
    """Compute cluster-wide throughput stats for every step.

    ``records`` maps step -> list of raw throughput records. Steps whose
    entries produce no valid stats (empty or zero-length wall-clock span)
    are omitted from the returned dict.
    """
    return {
        step_id: step_stats
        for step_id, step_entries in records.items()
        if (step_stats := _calculate_throughput_stats(step_entries))
    }
154+
155+
def calculate_per_node_throughput(records, ranks_per_node):
    """Compute per-node throughput stats for every step.

    ``records`` maps step -> list of raw throughput records;
    ``ranks_per_node`` is the number of ranks hosted on each node. Returns
    a nested mapping step -> node_id -> stats dict as produced by
    ``_calculate_throughput_stats``. Entries whose rank cannot be parsed
    are skipped.
    """
    per_node_stats = defaultdict(lambda: defaultdict(dict))
    for step, entries in records.items():
        node_groups = defaultdict(list)
        for e in entries:
            # The log regexes allow ranks of the form "3/64" (rank/world
            # size) as well as plain "3"; use only the rank component so
            # slash-form entries are grouped instead of silently dropped.
            try:
                node_id = int(e["rank"].split("/", 1)[0]) // ranks_per_node
            except ValueError:
                continue
            node_groups[node_id].append(e)
        for node_id, node_entries in node_groups.items():
            stats = _calculate_throughput_stats(node_entries)
            if stats:
                per_node_stats[step][node_id] = stats
    return per_node_stats
135171
136172
137173def analyze_step_time_breakdown (log_file_path ):
@@ -280,6 +316,50 @@ def calculate_total_training_time(log_file_path):
280316 return total_training_time
281317
282318
def print_total_throughput_stats(mode, throughput_stats):
    """Print a per-step table of cluster-wide throughput plus an average.

    ``mode`` is "Read" or "Write"; ``throughput_stats`` maps step -> stats
    dict as produced by ``calculate_total_throughput``. Prints nothing when
    the mapping is empty. With more than one step, the average excludes the
    first (warm-up) step.
    """
    if not throughput_stats:
        return
    print(f"--- Cluster-Wide Total {mode} Throughput ---")
    print(f"{'Step':<8} | {'Total Data (GB)':<15} | {'Time (s)':<10} | {'Throughput (GB/s)':<20}")
    print("-" * 65)
    sorted_steps = sorted(throughput_stats.keys())
    # Collect throughputs in sorted-step order so the slice below really
    # excludes the earliest step; dict insertion order is not guaranteed
    # to be sorted by step number.
    all_throughput = [throughput_stats[step]["throughput"] for step in sorted_steps]
    for step in sorted_steps:
        s = throughput_stats[step]
        print(f"{step:<8} | {s['total_gb']:<15.2f} | {s['duration']:<10.3f} | {s['throughput']:<20.4f}")

    if len(all_throughput) > 1:
        avg = np.mean(all_throughput[1:])
        print(f"Cluster-Wide Average {mode} Throughput Across Steps (Excluding first {mode}): {avg:.4f} GB/s\n")
    elif all_throughput:
        print(f"Cluster-Wide Average {mode} Throughput Across Steps: {np.mean(all_throughput):.4f} GB/s\n")
335+
336+
def print_per_node_throughput_stats(mode, raw_records, ranks_per_node):
    """Print a per-step, per-node throughput table and per-node averages.

    ``mode`` is "Read" or "Write"; ``raw_records`` maps step -> list of raw
    throughput records. Prints nothing when no per-node stats can be
    computed. With more than one step, the per-node averages exclude the
    first step.
    """
    per_node_stats = calculate_per_node_throughput(raw_records, ranks_per_node)
    if not per_node_stats:
        return
    print(f"--- Per-Node {mode} Throughput (Ranks per Node: {ranks_per_node}) ---")
    print(f"{'Step':<8} | {'Node':<6} | {'Total Data (GB)':<15} | {'Time (s)':<10} | {'Throughput (GB/s)':<20}")
    print("-" * 75)

    node_averages = defaultdict(list)
    sorted_steps = sorted(per_node_stats.keys())
    multiple_steps = len(sorted_steps) > 1
    for step in sorted_steps:
        for node_id in sorted(per_node_stats[step]):
            s = per_node_stats[step][node_id]
            print(
                f"{step:<8} | {node_id:<6} | {s['total_gb']:<15.2f} | {s['duration']:<10.3f} | {s['throughput']:<20.4f}"
            )
            # Skip the first step when averaging — presumably a warm-up step.
            if multiple_steps and step != sorted_steps[0]:
                node_averages[node_id].append(s["throughput"])

    if node_averages:
        print(f"\nPer-Node Average {mode} Throughput Across Steps (Excluding first {mode}):")
        for node_id in sorted(node_averages):
            print(f"Node {node_id}: {np.mean(node_averages[node_id]):.4f} GB/s")
        print()
361+
362+
283363def main ():
284364 """Main function."""
285365 parser = argparse .ArgumentParser (
@@ -298,6 +378,7 @@ def main():
298378 parser .add_argument ("log_file" , nargs = "?" , default = None , help = "Path to the log file to parse." )
299379 parser .add_argument ("--src-dir" , default = "src" , help = "Source directory to scan for instrumented functions." )
300380 parser .add_argument ("--save-functions" , help = "Path to save the instrumented functions list as a JSON file." )
381+ parser .add_argument ("--ranks-per-node" , type = int , default = 8 , help = "Number of ranks per node." )
301382
302383 args = parser .parse_args ()
303384
@@ -317,7 +398,7 @@ def main():
317398 print ("*********" * 8 )
318399 print ()
319400
320- data , ordered_functions = parse_log_file (args .log_file )
401+ data , ordered_functions , raw_throughput_records = parse_log_file (args .log_file )
321402 stats = calculate_statistics (data )
322403 overall_stats = calculate_overall_statistics (data )
323404
@@ -384,6 +465,16 @@ def main():
384465 )
385466 print ()
386467
468+ # Report Total Throughput
469+ cluster_read_stats = calculate_total_throughput (raw_throughput_records ["Read" ])
470+ cluster_write_stats = calculate_total_throughput (raw_throughput_records ["Write" ])
471+ print_total_throughput_stats ("Read" , cluster_read_stats )
472+ print_total_throughput_stats ("Write" , cluster_write_stats )
473+
474+ # Report Per-Node Throughput
475+ print_per_node_throughput_stats ("Read" , raw_throughput_records ["Read" ], args .ranks_per_node )
476+ print_per_node_throughput_stats ("Write" , raw_throughput_records ["Write" ], args .ranks_per_node )
477+
387478 print ("--- Step-to-Step Time Gap Analysis ---" )
388479 print ("Note: 'Total Gap' is the wall-clock time elapsed since the previous step finished." )
389480 print (" 'Other Time' = Total Gap - Train Time." )
0 commit comments