Skip to content

[training, perf] fix: THD-aware FLOPS via cu_seqlens (Σᵢ sᵢ²)#3839

Open
cuichenx wants to merge 23 commits into
mainfrom
chcui/thd-flops-fix
Open

[training, perf] fix: THD-aware FLOPS via cu_seqlens (Σᵢ sᵢ²)#3839
cuichenx wants to merge 23 commits into
mainfrom
chcui/thd-flops-fix

Conversation

@cuichenx

@cuichenx cuichenx commented May 15, 2026

Copy link
Copy Markdown
Contributor

Summary

The FLOPS calculator treated every batch as BSHD with the attention term scaling as pack_length² (seq_length² after #3529). For THD packed training — both offline-packed LLM SFT (via PackedSequenceSpecs) and VLM in-batch packing — the actual attention work is Σᵢ sᵢ² over the sub-sequence lengths within each pack, not pack_length². This PR threads cu_seqlens from the dataloader into the FLOPS accumulator so the attention term reports the true work, and fixes the per-interval throughput logging to use the same source.

The fix

accumulate_flops_metadata() in src/megatron/bridge/training/utils/flop_utils.py, called once per micro-batch from the step functions:

  • Accumulates seqlen_sum (Σ padded tokens — drives the linear MLP/proj/logit terms) and the THD attention term seqlen_sq_sum = Σᵢ sᵢ², computed inline from cu_seqlens (preferring cu_seqlens_unpadded when pad_seq_to_mult > 1; the matching *_argmin truncates the padded tail). The squared sub-sequence lengths stay on-device (_scalar_sum_for_accumulator, no .item()), so the per-micro-batch forward path takes no host sync.
  • Falls back to the BSHD value mbs * seq_len² when cu_seqlens is absent or degenerate (dense pretraining / non-packed) — bit-exact with pre-fix behavior.
  • Also accumulates num_vision_patches (model-agnostic: each step computes the count from its own image/video grid) for the ViT term.

Per step, the main loop resolves the accumulators to a data-parallel-global total via resolve_global_flops_seqlen_stats — an exact SUM all-reduce over the pure DP group (CP ranks share the same cu_seqlens, so reducing over DP×CP would double-count) — feeds num_floating_point_operations, and folds the result into floating_point_operations_so_far. One tiny (3-element) reduce + one host sync per step, at the existing end-of-step sync boundary.

Wired into the step functions:

  • gpt_step.py — LLM, offline THD via PackedSequenceSpecs
  • vlm_step.py — legacy VLM in-batch packing
  • qwen_vl/qwen3_vl_step.py — Qwen3-VL family in-batch packing
  • qwen_omni/qwen3_omni_step.py — Qwen3-Omni (vision-patch tracking + BSHD seqlen²)

Throughput logging fix

training_log now derives the per-interval throughput/tflops (WandB / TensorBoard / MLflow / Comet / console) from the cumulative FLOPs delta since the last log ÷ the interval elapsed time — numerator and denominator over the same window. Previously it used the log-boundary step's FLOPS ÷ the interval-average iteration time, which over- or under-reports under variable-length THD (the per-step FLOPS vary). The cumulative floating_point_operations_so_far is now the single source of truth, so the logged throughput cannot disagree with it, and training_log's own resolve_… call (a second per-log DP all-reduce) is removed.

Why

For a packed batch of length L_pack holding N real sub-sequences of lengths s₁ … s_N:

  • BSHD approximation (old): attention FLOPS ∝ L_pack²
  • THD truth (new): attention FLOPS ∝ Σᵢ sᵢ²

Equality only when N == 1 (a single sub-seq fills the pack). For typical packed SFT or in-batch VLM packing, Σᵢ sᵢ² ≪ L_pack² because cross-sub-seq attention does not happen.

This is a metric correction, not a throughput regression

Replacing the inflated BSHD pack_length² with the accurate (smaller) THD Σᵢ sᵢ² lowers the reported TFLOP/s by ~7% on packed SFT — at unchanged wall-clock. Confirmed on cw-dfw (qwen3_8b_sft, DP=8):

Variant s/step (wall-clock) reported TFLOP/s/GPU
baseline / BSHD over-count 1.3666 ~578
this PR (THD truth) 1.3623–1.3671 ~543

Same clock, lower number — the old number was an over-count. The reported-TFLOP/s drop is the fix working, not a slowdown. (An earlier iteration of this PR added precompute/defer/buffer machinery to "recover" the 7%; that was chasing a measurement artifact in the reported metric, so it has been removed in favor of the straightforward inline implementation above. CI perf goldens for the affected recipes need re-baselining to the corrected values — tracked separately.)

Verification (FLOPS-value correctness)

Two identical workspaces on cw-dfw — only the source files differ. Same model, dataset, seed, parallelism. Identical loss and iteration times across paired runs confirm only the reported FLOPS changes.

Recipe seq Packing Baseline (BSHD) Fix (THD) Baseline over-count
qwen3_8b_sft_config 2048 offline THD (SQuAD) 162.6 TFLOP/s/GPU 155.8 TFLOP/s/GPU +4.4%
qwen3_8b_sft_config 4096 offline THD (SQuAD) 339.9 TFLOP/s/GPU 156.7 TFLOP/s/GPU +117% (2.17×)
qwen35_vl_9b_sft_config 4096 in-batch (CORD-v2) 261.6 TFLOP/s/GPU 88.9 TFLOP/s/GPU +194% (~3×)

The seq=2048 vs 4096 pair on the same LLM recipe is the cleanest demonstration: the fix reports near-identical TFLOP/s/GPU (155.8 vs 156.7) because actual attention work is set by per-sample lengths, not pack length. The baseline doubles (162.6 → 339.9) because its pack_length² formula scales with pack size while the real work doesn't — a clear, growing over-count.

Tests

tests/unit_tests/training/utils/test_flop_utils.py:

  • BSHD fallback (no cu_seqlens) → mbs * seq_len²
  • THD with cu_seqlens → Σᵢ sᵢ²; padded cu_seqlens + cu_seqlens_argmin truncation; cu_seqlens_unpadded precedence; additive accumulation across micro-batches; zero-length sub-seq contributes 0; 32×256 pack ≈ 32× smaller than BSHD; tokens-None no-op; inline-path computes without buffering
  • resolve_global_flops_seqlen_stats: extrapolation fallback, VPP correction, empty → None, scalar-tensor coercion, dp_group-ignored-when-dist-not-initialized

tests/unit_tests/training/test_gpt_step.py: get_batch 10-tuple (packed-seq metadata) unchanged.

tests/unit_tests/training/utils/test_train_utils.py: training_log throughput = interval-FLOPs-delta ÷ interval-time ÷ world-size, and the _flops_at_last_log anchor advances.

Out of scope

  • llava_step.py and audio_lm_step.py still fall back to cfg.model.seq_length (helper not wired). Easy follow-up.
  • Hybrid model path (hybrid_flops) still passes scalar seq_len, dropping quadratic accuracy for hybrid attention — independent of this PR.
  • CP > 1: the linear term uses per-CP-rank tokens while Σᵢ sᵢ² comes from the full (un-sharded) cu_seqlens — a cp_size scale mismatch. Not introduced here and not exercised by the CP=1 verification runs. Fixed in follow-up [training, perf] fix: CP-correct token count in FLOPS (cp_size) #4161.

Labels

bug · area:perf · area:training · needs-review

Packed THD training (offline-packed LLM SFT and VLM in-batch packing)
over-counts attention FLOPS by treating the whole pack as one length-
seq_length sequence (pack_length²). Actual attention work is Σᵢ sᵢ²
over the real sub-sequence lengths.

New helper accumulate_flops_metadata() in flop_utils.py extracts the
real sub-seq lengths from cu_seqlens (preferring cu_seqlens_unpadded
when present) and feeds Σᵢ sᵢ² into the existing seqlen_squared_sum
accumulator from #3529. Falls back to BSHD mbs * seq_len² when no
cu_seqlens is provided — bit-exact identical to legacy on dense
pretraining and non-packed paths.

Wired into gpt_step, vlm_step, qwen3_vl_step, and qwen3_omni_step.

Verified on cw-dfw (same seed, same data, same iter times, identical
loss values across paired runs — only the reported TFLOPS differs):
- qwen3_8b_sft  seq=2048: baseline 162.6 vs fix 155.8 TFLOP/s/GPU (+4%)
- qwen3_8b_sft  seq=4096: baseline 339.9 vs fix 156.7 TFLOP/s/GPU (+117%)
- qwen35_vl_9b_sft     : baseline 261.6 vs fix  88.9 TFLOP/s/GPU (+194%)

The seq=2048→4096 pair on the same LLM recipe is the cleanest
demonstration: the fix is near-flat (155.8 vs 156.7 — attention work
is determined by per-sample lengths, not pack length) while the
baseline doubles because its pack_length² scales quadratically.

9 new unit tests in test_flop_utils.py::TestAccumulateFlopsMetadata
cover the BSHD fallback, THD with cu_seqlens, padded cu_seqlens via
cu_seqlens_argmin, cu_seqlens_unpadded precedence, additive
accumulation, and the regression headline (32-sample pack → 32x
smaller attention work than BSHD approximation).

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx cuichenx added bug Something isn't working area:perf Performance optimizations and benchmarking area:training Training loop, callbacks, and runtime integration needs-review PR is ready for code review and waiting on a reviewer labels May 15, 2026
@yaoyu-33 yaoyu-33 removed the area:training Training loop, callbacks, and runtime integration label May 18, 2026
…lict-nora

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

# Conflicts:
#	src/megatron/bridge/training/utils/flop_utils.py
@copy-pr-bot

copy-pr-bot Bot commented May 28, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yaoyu-33

Copy link
Copy Markdown
Contributor

/ok to test 2edde36

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33

Copy link
Copy Markdown
Contributor

/ok to test a78f863

Comment thread src/megatron/bridge/training/utils/flop_utils.py Outdated
Comment thread src/megatron/bridge/training/vlm_step.py Outdated
@gautham-kollu

Copy link
Copy Markdown
Contributor

Failing UT logs

2026-06-10T23:09:45.9428001Z =================================== FAILURES =================================== 2026-06-10T23:09:45.9429107Z �[31m�[1m_ TestLocalCheckpointing.test_local_checkpoint_save_resume_with_most_recent_k __�[0m 2026-06-10T23:09:45.9430153Z �[1m�[31mtests/functional_tests/test_groups/training/test_local_checkpointing.py�[0m:325: in test_local_checkpoint_save_resume_with_most_recent_k 2026-06-10T23:09:46.0485380Z �[0mpretrain(cfg_run2, forward_step, callbacks=[cb])�[90m�[39;49;00m 2026-06-10T23:09:46.0486043Z �[1m�[31msrc/megatron/bridge/utils/decorators.py�[0m:39: in wrapper 2026-06-10T23:09:46.0490963Z �[0m�[94mreturn�[39;49;00m func(*args, **kwargs)�[90m�[39;49;00m 2026-06-10T23:09:46.0491453Z ^^^^^^^^^^^^^^^^^^^^^�[90m�[39;49;00m 2026-06-10T23:09:46.0493184Z �[1m�[31msrc/megatron/bridge/training/pretrain.py�[0m:98: in pretrain 2026-06-10T23:09:46.0497133Z �[0m_pretrain(state=state, forward_step_func=forward_step_func, callback_manager=callback_manager)�[90m�[39;49;00m 2026-06-10T23:09:46.0497842Z �[1m�[31msrc/megatron/bridge/training/pretrain.py�[0m:142: in _pretrain 2026-06-10T23:09:46.0499175Z �[0mtrain(�[90m�[39;49;00m 2026-06-10T23:09:46.0502078Z �[1m�[31msrc/megatron/bridge/training/train.py�[0m:455: in train 2026-06-10T23:09:46.0502504Z �[0m) = wrapped_train_step(�[90m�[39;49;00m 2026-06-10T23:09:46.0502970Z �[1m�[31msrc/megatron/bridge/training/train.py�[0m:882: in train_step 2026-06-10T23:09:46.0506169Z �[0mlosses_reduced = forward_backward_func(�[90m�[39;49;00m 2026-06-10T23:09:46.0507868Z �[1m�[31m3rdparty/Megatron-LM/megatron/core/pipeline_parallel/schedules.py�[0m:743: in forward_backward_no_pipelining 2026-06-10T23:09:46.0509872Z �[0moutput_tensor, num_tokens = forward_step(�[90m�[39;49;00m 2026-06-10T23:09:46.0510683Z �[1m�[31m3rdparty/Megatron-LM/megatron/core/pipeline_parallel/schedules.py�[0m:476: in forward_step 2026-06-10T23:09:46.0513622Z �[0moutput_tensor, num_tokens = forward_step_calc_loss(�[90m�[39;49;00m 2026-06-10T23:09:46.0515473Z �[1m�[31m3rdparty/Megatron-LM/megatron/core/pipeline_parallel/schedules.py�[0m:292: in forward_step_calc_loss 2026-06-10T23:09:46.0520732Z �[0moutputs = loss_func(output_tensor)�[90m�[39;49;00m 2026-06-10T23:09:46.0521434Z ^^^^^^^^^^^^^^^^^^^^^^^^�[90m�[39;49;00m 2026-06-10T23:09:46.0521929Z �[1m�[31msrc/megatron/bridge/training/losses.py�[0m:73: in masked_next_token_loss 2026-06-10T23:09:46.0524298Z �[0mrerun_state_machine.validate_result(�[90m�[39;49;00m 2026-06-10T23:09:46.0524838Z �[1m�[31m3rdparty/Megatron-LM/megatron/core/rerun_state_machine.py�[0m:533: in validate_result 2026-06-10T23:09:46.0527741Z �[0m�[94mraise�[39;49;00m �[96mRuntimeError�[39;49;00m(full_message)�[90m�[39;49;00m 2026-06-10T23:09:46.0528574Z �[1m�[31mE RuntimeError: Rank 1, node 70b80f25ccf5, device 1, iteration 8: Unexpected result nan (message='found NaN in local forward loss calculation')�[0m

@gautham-kollu

Copy link
Copy Markdown
Contributor

/ok to test 585800f

@cuichenx

cuichenx commented Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

error not reproducible with release container. pulling gha container to debug
edit: also not reproducible with gha container. test is flaky

@cuichenx

Copy link
Copy Markdown
Contributor Author

/ok to test 5d61c29

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx

Copy link
Copy Markdown
Contributor Author

/ok to test 9d3f709

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:perf Performance optimizations and benchmarking bug Something isn't working needs-more-tests Requires additional L0 and L1 test coverage before merge ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants