Skip to content

Commit fafaa0c

Browse files
tdenejalbericiola
committed
Add RL token throughput and packing metrics
Co-authored-by: Jorge Albericio <jalbericiola@nvidia.com>
1 parent f8becec commit fafaa0c

File tree

4 files changed

+202
-0
lines changed

4 files changed

+202
-0
lines changed

megatron/rl/rl_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@
5757
get_sequence_packing_tensorboard_metrics,
5858
get_sequence_packing_log_info,
5959
get_default_packed_seq_params,
60+
get_packing_actual_tokens,
61+
get_packing_compute_tokens,
62+
get_packing_efficiency,
63+
get_packing_avg_seq_length,
6064
update_microbatch_calculator,
6165
)
6266
from megatron.rl.agent.api import (
@@ -300,11 +304,22 @@ def __init__(self):
300304
self.last_collection_iteration = 0
301305
self.sequences_this_iteration_on_rank = 0
302306
self.latest_batch_num_sequences = 0
307+
# Derived throughput metrics (set by training_log, read by RLProfiler)
308+
self.tokens_per_sec = None
309+
self.tokens_per_sec_per_gpu = None
310+
self.actual_tokens_per_sec = None
311+
self.actual_tokens_per_sec_per_gpu = None
312+
self.packing_efficiency = None
303313

304314
def reset_iteration_counters(self, iteration):
    """Clear all per-iteration counters and derived throughput metrics.

    Args:
        iteration: The iteration number to record as the last collection
            iteration.
    """
    self.sequences_this_iteration_on_rank = 0
    self.last_collection_iteration = iteration
    # Derived throughput metrics are recomputed by training_log each
    # iteration; clear them so stale values are never reported.
    for metric_name in (
        'tokens_per_sec',
        'tokens_per_sec_per_gpu',
        'actual_tokens_per_sec',
        'actual_tokens_per_sec_per_gpu',
        'packing_efficiency',
    ):
        setattr(self, metric_name, None)
308323

309324
def increment_sequences(self, count):
310325
"""Increment the sequence counter."""

megatron/rl/sequence_packing_utils.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,3 +1173,79 @@ def get_sequence_packing_tensorboard_metrics(args):
11731173
metrics['bin-batch-size'] = bin_batch_size
11741174
metrics['consumed-bins'] = args.consumed_train_bins
11751175
return metrics
1176+
1177+
1178+
def get_packing_actual_tokens(packing_context: PackingContext) -> int:
    """Return the count of real (non-padding) tokens packed on this rank.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        Total number of actual tokens across all bins on this rank.
    """
    if packing_context is None or packing_context.packing_info is None:
        return 0

    info = packing_context.packing_info
    lengths = info.seq_lengths

    # Walk every bin assigned to this rank and accumulate the true length
    # of each sequence it contains (bin capacity/padding is excluded).
    total = 0
    for bin_indices in info.bin_seq_indices:
        for seq_idx in bin_indices:
            total += lengths[seq_idx]
    return total
1200+
1201+
1202+
def get_packing_compute_tokens(packing_context: PackingContext) -> int:
    """Return total compute tokens (padding included) packed on this rank.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        Total compute tokens (num_bins * bin_size) on this rank.
    """
    if packing_context is None or packing_context.packed_trajs is None:
        return 0

    # Every bin occupies bin_size token slots regardless of how full it is,
    # so the compute cost is simply rows * columns of the packed tensor.
    num_bins, bin_size = packing_context.packed_trajs.shape[:2]
    return num_bins * bin_size
1216+
1217+
1218+
def get_packing_efficiency(packing_context: PackingContext) -> float:
    """Return the packing efficiency (actual tokens / total capacity).

    Capacity is bins_per_rank * bin_size * data_parallel_world_size.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        Packing efficiency as a float between 0 and 1 (0.0 when there is
        no packing info or zero capacity).
    """
    if packing_context is None or packing_context.packing_info is None:
        return 0.0

    packed = packing_context.packed_trajs
    if packed is None:
        capacity_per_rank = 0
    else:
        capacity_per_rank = packed.shape[0] * packed.shape[1]
    total_capacity = capacity_per_rank * mpu.get_data_parallel_world_size()

    if total_capacity == 0:
        return 0.0

    # NOTE(review): seq_lengths appears to cover every sequence in the
    # global batch (capacity is scaled by DP world size to match) — the
    # numerator is presumably already a cross-rank total; confirm upstream.
    return sum(packing_context.packing_info.seq_lengths) / total_capacity
1240+
1241+
1242+
def get_packing_avg_seq_length(packing_context: PackingContext) -> float:
    """Get the average sequence length across all sequences in the packing context.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        Mean of seq_lengths, or 0.0 when there is no packing info or no
        sequences.
    """
    if packing_context is None or packing_context.packing_info is None:
        return 0.0

    seq_lengths = packing_context.packing_info.seq_lengths
    # A single falsiness check covers both None and empty containers; the
    # previous `or len(seq_lengths) == 0` clause was redundant.
    if not seq_lengths:
        return 0.0

    return sum(seq_lengths) / len(seq_lengths)

megatron/training/training.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2165,6 +2165,66 @@ def training_log(
21652165
total_loss_dict[skipped_iters_key]
21662166
)
21672167
log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key])
2168+
2169+
# Compute tokens/sec metrics for logging
2170+
tokens_per_sec = None
2171+
tokens_per_sec_per_gpu = None
2172+
actual_tokens_per_sec = None
2173+
actual_tokens_per_sec_per_gpu = None
2174+
packing_efficiency = None
2175+
2176+
if hasattr(args, 'seq_length') and args.seq_length > 0:
2177+
# Compute tokens (includes padding for consistency with tensor shapes)
2178+
tokens_per_iteration = batch_size * args.seq_length
2179+
tokens_per_sec = tokens_per_iteration / elapsed_time_per_iteration
2180+
tokens_per_sec_per_gpu = tokens_per_sec / args.world_size
2181+
2182+
# For sequence packing, also compute actual tokens (non-padding)
2183+
if has_rl_utils and getattr(args, 'perform_rl_step', False) and getattr(args, 'rl_use_sequence_packing', False):
2184+
runtime_state = rl_utils.get_rl_runtime_state()
2185+
if runtime_state.packing_context is not None:
2186+
# Get actual tokens from packing context
2187+
actual_tokens = rl_utils.get_packing_actual_tokens(runtime_state.packing_context)
2188+
compute_tokens = rl_utils.get_packing_compute_tokens(runtime_state.packing_context)
2189+
2190+
# Scale to global batch (all DP ranks)
2191+
actual_tokens_global = actual_tokens * mpu.get_data_parallel_world_size()
2192+
2193+
actual_tokens_per_sec = actual_tokens_global / elapsed_time_per_iteration
2194+
actual_tokens_per_sec_per_gpu = actual_tokens_per_sec / args.world_size
2195+
packing_efficiency = rl_utils.get_packing_efficiency(runtime_state.packing_context)
2196+
2197+
# Add tokens/sec to log string
2198+
log_string += f' toks/s: {tokens_per_sec:.0f} |'
2199+
log_string += f' toks/s/gpu: {tokens_per_sec_per_gpu:.0f} |'
2200+
if actual_tokens_per_sec is not None:
2201+
log_string += f' actual_toks/s: {actual_tokens_per_sec:.0f} |'
2202+
log_string += f' actual_toks/s/gpu: {actual_tokens_per_sec_per_gpu:.0f} |'
2203+
log_string += f' packing_eff: {packing_efficiency:.1%} |'
2204+
2205+
# Store derived throughput metrics on RLRuntimeState so that
2206+
# downstream consumers (e.g. RLProfiler) can read them.
2207+
if has_rl_utils and getattr(args, 'perform_rl_step', False):
2208+
runtime_state = rl_utils.get_rl_runtime_state()
2209+
runtime_state.tokens_per_sec = tokens_per_sec
2210+
runtime_state.tokens_per_sec_per_gpu = tokens_per_sec_per_gpu
2211+
runtime_state.actual_tokens_per_sec = actual_tokens_per_sec
2212+
runtime_state.actual_tokens_per_sec_per_gpu = actual_tokens_per_sec_per_gpu
2213+
runtime_state.packing_efficiency = packing_efficiency
2214+
2215+
# Log average (non-padding) sequence length. With sequence packing this
2216+
# shows how long the real sequences are; without packing it equals seq_length
2217+
# (all sequences are padded to the same length) — still useful as a baseline
2218+
# so the metric is always present for comparison.
2219+
if has_rl_utils and getattr(args, 'perform_rl_step', False):
2220+
runtime_state = rl_utils.get_rl_runtime_state()
2221+
packing_ctx = runtime_state.packing_context
2222+
if getattr(args, 'rl_use_sequence_packing', False) and packing_ctx is not None:
2223+
avg_seq_length = rl_utils.get_packing_avg_seq_length(packing_ctx)
2224+
log_string += f' avg_seq_len: {avg_seq_length:.1f} |'
2225+
elif args.log_throughput:
2226+
log_string += f' avg_seq_len: {args.seq_length} |'
2227+
21682228
if should_reset:
21692229
total_loss_dict[advanced_iters_key] = 0
21702230
total_loss_dict[skipped_iters_key] = 0

tests/unit_tests/rl/test_sequence_packing_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,57 @@ def test_compute_packed_inference_logprobs_stats_shape_mismatch():
412412
assert group_stats.mean_piold_to_inf_prob is None
413413

414414

415+
def test_packing_observability_metrics():
    """Test various observability metrics related to sequence packing."""

    # Two bins of capacity 16 holding four sequences of known lengths:
    # bin 0 -> seqs 0 (len 5) and 1 (len 3) = 8 actual tokens;
    # bin 1 -> seqs 2 (len 10) and 3 (len 4) = 14 actual tokens.
    seq_lengths = [5, 3, 10, 4]
    packing_info = sequence_packing_utils.PackingInfo(
        bin_seq_indices=[[0, 1], [2, 3]],
        seq_starts={0: [0, 5], 1: [0, 10]},
        seq_lengths=seq_lengths,
        seq_to_bin_idx=[0, 0, 1, 1],
        packing_algo='fifo',
    )

    num_bins, bin_size = 2, 16
    ctx = sequence_packing_utils.PackingContext(
        bin_size=bin_size,
        packer=None,
        packing_info=packing_info,
        original_generation_masks=None,
        original_trajs=None,
        packed_trajs=torch.zeros(num_bins, bin_size, dtype=torch.long),
        packed_position_ids=None,
        packed_attention_mask=None,
        packed_loss_mask=None,
    )

    # Actual tokens: sum of every seq_length referenced by bin_seq_indices.
    assert sequence_packing_utils.get_packing_actual_tokens(ctx) == 5 + 3 + 10 + 4

    # Compute tokens: bins times bin capacity.
    assert sequence_packing_utils.get_packing_compute_tokens(ctx) == 2 * 16

    # Average sequence length is the plain mean of seq_lengths.
    assert sequence_packing_utils.get_packing_avg_seq_length(ctx) == pytest.approx(22 / 4)

    # Efficiency: total actual tokens over bins_per_rank * bin_size * num_ranks.
    with patch('megatron.core.mpu.get_data_parallel_world_size', return_value=4):
        # total_actual = 22, capacity = 2 * 16 * 4 = 128
        assert sequence_packing_utils.get_packing_efficiency(ctx) == pytest.approx(22 / 128)

    # A missing context yields zero for every metric.
    assert sequence_packing_utils.get_packing_actual_tokens(None) == 0
    assert sequence_packing_utils.get_packing_compute_tokens(None) == 0
    assert sequence_packing_utils.get_packing_efficiency(None) == 0.0
    assert sequence_packing_utils.get_packing_avg_seq_length(None) == 0.0
464+
465+
415466
@pytest.mark.parametrize("num_sequences", [1, 10, 48, 49, 50])
416467
def test_cu_seqlens_size(num_sequences):
417468
"""Test that cu_seqlens always has a fixed size regardless of how many sequences are packed."""

0 commit comments

Comments
 (0)