Skip to content

Commit 150d52c

Browse files
authored
feat(scripts): use on_train_batch_start for precise step timing (#39)
Sample output: https://gist.github.com/kkkapu/35ad72c9408be8b69bab2f00c13f5df2
1 parent 6f36c9c commit 150d52c

File tree

1 file changed

+60
-56
lines changed

1 file changed

+60
-56
lines changed

scripts/parse_log_and_summarize.py

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,12 @@ def parse_log_file(log_file):
5959
# Standard performance log: "... took X.Xs"
6060
log_pattern = re.compile(r"\[MLF.* Step=(-?\d+) Rank=(-?[\d/]+) (.*?):[\d]+\] (.*?) took ([\d.]+)s")
6161

62-
# Format: [MLF YYYY-MM-DD HH:MM:SS,mmm ...]
63-
timestamp_prefix = r"\[MLF (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3})"
64-
6562
# Throughput logs
6663
# Read/Write patterns to capture timestamp and bytes
67-
# Capture groups: 1: Timestamp, 2: Step, 3: Rank, 4: Bytes, 5: Duration
64+
# Capture groups: 1: Timestamp, 2: Step, 3: Rank, 4: Bytes, 5: Duration, 6: GB/s, 7: Num files/buckets
6865
# [Timestamp + ...] + "Read 123 bytes in 0.123 s (x.xx GB/s) from 1 buckets"
66+
# Format: [MLF YYYY-MM-DD HH:MM:SS,mmm ...]
67+
timestamp_prefix = r"\[MLF (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3})"
6968
read_throughput_pattern = re.compile(
7069
timestamp_prefix + r".*? Step=(-?\d+) Rank=(-?[\d/]+) .*?\] Read (\d+) "
7170
r"bytes in ([\d.]+) s \(([\d.]+) GB/s\) from (\d+) files"
@@ -80,6 +79,14 @@ def parse_log_file(log_file):
8079
# Format: { 'Read': { step: [ (end_time, duration, bytes, rank), ... ] }, 'Write': { ... } }
8180
raw_throughput_records = {"Read": defaultdict(list), "Write": defaultdict(list)}
8281

82+
# Pattern to capture the start of a training step
83+
# Groups: 1: Timestamp, 2: Step, 3: Rank
84+
batch_start_pattern = re.compile(
85+
timestamp_prefix + r".*? Step=(-?\d+) Rank=(-?[\d/]+) .*? event=on_train_batch_start"
86+
)
87+
# Create a dictionary to track the earliest start time for each step
88+
step_start_times = {}
89+
8390
train_step_pattern = re.compile(r"global_step: (\d+).*?train_step_timing in s: ([\d.]+)")
8491
data = defaultdict(lambda: defaultdict(list))
8592
ordered_functions = []
@@ -108,7 +115,6 @@ def parse_log_file(log_file):
108115
if write_match:
109116
_process_throughput_match(write_match, "Write", raw_throughput_records)
110117
continue
111-
continue
112118

113119
# Check for Train Step
114120
train_match = train_step_pattern.search(line)
@@ -119,7 +125,20 @@ def parse_log_file(log_file):
119125
ordered_functions.append(metric_name)
120126
data[metric_name][int(step)].append((float(time_taken), "NA"))
121127

122-
return data, ordered_functions, raw_throughput_records
128+
start_match = batch_start_pattern.search(line)
129+
if start_match:
130+
ts_str, step_idx, _ = start_match.groups()
131+
step_idx = int(step_idx)
132+
ts_dt = datetime.strptime(ts_str.replace(",", "."), "%Y-%m-%d %H:%M:%S.%f")
133+
134+
# Record the earliest timestamp (of on_train_batch_start) seen for this step (first rank that logs it)
135+
# The first step in MLF is logged as -1 (-1->1->2) while in Nemo it's 0 (0->1->2)
136+
if step_idx == -1:
137+
step_idx = 0
138+
step_start_times[step_idx] = min(ts_dt, step_start_times.get(step_idx, ts_dt))
139+
continue
140+
141+
return data, ordered_functions, raw_throughput_records, step_start_times
123142

124143

125144
def _process_throughput_match(match, mode, raw_records):
@@ -170,58 +189,43 @@ def calculate_per_node_throughput(records, ranks_per_node):
170189
return per_node_stats
171190

172191

173-
def analyze_step_time_breakdown(log_file_path):
192+
def analyze_step_time_breakdown(step_start_times, data):
174193
"""
175-
Analyzes the wall-clock time gap between consecutive global steps to identify overheads.
194+
Analyzes the wall-clock time gap between consecutive global steps.
176195
177-
This function calculates:
178-
1. Total Gap: Wall-clock time elapsed between the end of step N-1 and step N.
179-
2. Other Time (Overhead): Total Gap minus the reported training time (train_step_timing from NeMo logs).
196+
Using the earliest 'on_train_batch_start' event as the definitive
197+
timestamp for the beginning of each step.
180198
"""
181-
timestamp_pattern = re.compile(r"\[NeMo \w (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})")
182-
train_step_pattern = re.compile(r"global_step: (\d+).*?train_step_timing in s: ([\d.]+)")
183-
184-
step_data = []
185-
last_seen_timestamp = None
186-
187-
try:
188-
with open(log_file_path, "r") as f:
189-
for line in f:
190-
ts_match = timestamp_pattern.search(line)
191-
if ts_match:
192-
try:
193-
last_seen_timestamp = datetime.strptime(ts_match.group(1), "%Y-%m-%d %H:%M:%S")
194-
except ValueError:
195-
pass
196-
197-
step_match = train_step_pattern.search(line)
198-
if step_match and last_seen_timestamp:
199-
step_data.append(
200-
{
201-
"step": int(step_match.group(1)),
202-
"finish_time": last_seen_timestamp,
203-
"train_time": float(step_match.group(2)),
204-
}
205-
)
206-
except Exception as e:
207-
print(f"Error analyzing breakdown: {e}")
208-
return []
209-
210199
results = []
211-
for i in range(1, len(step_data)):
212-
prev, curr = step_data[i - 1], step_data[i]
213-
if curr["step"] > prev["step"]:
214-
time_delta = max((curr["finish_time"] - prev["finish_time"]).total_seconds(), 0.0)
215-
other_time = time_delta - curr["train_time"]
216-
results.append(
217-
{
218-
"step": curr["step"],
219-
"timestamp": curr["finish_time"],
220-
"total_gap": time_delta,
221-
"train_time": curr["train_time"],
222-
"other_time": other_time,
223-
}
224-
)
200+
# We need the reported train_step_timing to calculate overhead
201+
train_timings = data.get("train_step_timing", {})
202+
203+
sorted_steps = sorted(step_start_times.keys())
204+
205+
for i in range(len(sorted_steps) - 1):
206+
curr_step = sorted_steps[i]
207+
next_step = sorted_steps[i + 1]
208+
209+
# Calculate Total Gap between starts of two consecutive steps
210+
start_curr = step_start_times[curr_step]
211+
start_next = step_start_times[next_step]
212+
total_gap = (start_next - start_curr).total_seconds()
213+
214+
try:
215+
actual_train_time = train_timings[curr_step][0][0]
216+
except (KeyError, IndexError, TypeError):
217+
actual_train_time = 0.0
218+
other_time = total_gap - actual_train_time
219+
220+
results.append(
221+
{
222+
"step": curr_step,
223+
"timestamp": start_curr,
224+
"total_gap": total_gap,
225+
"train_time": actual_train_time,
226+
"other_time": other_time,
227+
}
228+
)
225229
return results
226230

227231

@@ -398,7 +402,7 @@ def main():
398402
print("*********" * 8)
399403
print()
400404

401-
data, ordered_functions, raw_throughput_records = parse_log_file(args.log_file)
405+
data, ordered_functions, raw_throughput_records, step_start_times = parse_log_file(args.log_file)
402406
stats = calculate_statistics(data)
403407
overall_stats = calculate_overall_statistics(data)
404408

@@ -484,7 +488,7 @@ def main():
484488
)
485489
print("-" * 85)
486490

487-
breakdown = analyze_step_time_breakdown(args.log_file)
491+
breakdown = analyze_step_time_breakdown(step_start_times, data)
488492
other_times = []
489493
if not breakdown:
490494
print("No consecutive steps or timestamps found to calculate gaps.")

0 commit comments

Comments (0)