[training, perf] fix: THD-aware FLOPS via cu_seqlens (Σᵢ sᵢ²)#3839
[training, perf] fix: THD-aware FLOPS via cu_seqlens (Σᵢ sᵢ²)#3839cuichenx wants to merge 23 commits into
Conversation
Packed THD training (offline-packed LLM SFT and VLM in-batch packing) over-counts attention FLOPS by treating the whole pack as one length- seq_length sequence (pack_length²). Actual attention work is Σᵢ sᵢ² over the real sub-sequence lengths. New helper accumulate_flops_metadata() in flop_utils.py extracts the real sub-seq lengths from cu_seqlens (preferring cu_seqlens_unpadded when present) and feeds Σᵢ sᵢ² into the existing seqlen_squared_sum accumulator from #3529. Falls back to BSHD mbs * seq_len² when no cu_seqlens is provided — bit-exact identical to legacy on dense pretraining and non-packed paths. Wired into gpt_step, vlm_step, qwen3_vl_step, and qwen3_omni_step. Verified on cw-dfw (same seed, same data, same iter times, identical loss values across paired runs — only the reported TFLOPS differs): - qwen3_8b_sft seq=2048: baseline 162.6 vs fix 155.8 TFLOP/s/GPU (+4%) - qwen3_8b_sft seq=4096: baseline 339.9 vs fix 156.7 TFLOP/s/GPU (+117%) - qwen35_vl_9b_sft : baseline 261.6 vs fix 88.9 TFLOP/s/GPU (+194%) The seq=2048→4096 pair on the same LLM recipe is the cleanest demonstration: the fix is near-flat (155.8 vs 156.7 — attention work is determined by per-sample lengths, not pack length) while the baseline doubles because its pack_length² scales quadratically. 9 new unit tests in test_flop_utils.py::TestAccumulateFlopsMetadata cover the BSHD fallback, THD with cu_seqlens, padded cu_seqlens via cu_seqlens_argmin, cu_seqlens_unpadded precedence, additive accumulation, and the regression headline (32-sample pack → 32x smaller attention work than BSHD approximation). Signed-off-by: Chen Cui <chcui@nvidia.com>
…lict-nora Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> # Conflicts: # src/megatron/bridge/training/utils/flop_utils.py
|
/ok to test 2edde36 |
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
|
/ok to test a78f863 |
|
Failing UT logs
|
|
/ok to test 585800f |
|
error not reproducible with release container. pulling gha container to debug |
|
/ok to test 5d61c29 |
Signed-off-by: Chen Cui <chcui@nvidia.com>
|
/ok to test 9d3f709 |
Summary
The FLOPS calculator treated every batch as BSHD with the attention term scaling as
pack_length²(seq_length²after #3529). For THD packed training — both offline-packed LLM SFT (viaPackedSequenceSpecs) and VLM in-batch packing — the actual attention work isΣᵢ sᵢ²over the sub-sequence lengths within each pack, notpack_length². This PR threadscu_seqlensfrom the dataloader into the FLOPS accumulator so the attention term reports the true work, and fixes the per-interval throughput logging to use the same source.The fix
accumulate_flops_metadata()insrc/megatron/bridge/training/utils/flop_utils.py, called once per micro-batch from the step functions:seqlen_sum(Σ padded tokens — drives the linear MLP/proj/logit terms) and the THD attention termseqlen_sq_sum = Σᵢ sᵢ², computed inline fromcu_seqlens(preferringcu_seqlens_unpaddedwhenpad_seq_to_mult > 1; the matching*_argmintruncates the padded tail). The squared sub-sequence lengths stay on-device (_scalar_sum_for_accumulator, no.item()), so the per-micro-batch forward path takes no host sync.mbs * seq_len²whencu_seqlensis absent or degenerate (dense pretraining / non-packed) — bit-exact with pre-fix behavior.num_vision_patches(model-agnostic: each step computes the count from its own image/video grid) for the ViT term.Per step, the main loop resolves the accumulators to a data-parallel-global total via
resolve_global_flops_seqlen_stats— an exact SUM all-reduce over the pure DP group (CP ranks share the samecu_seqlens, so reducing over DP×CP would double-count) — feedsnum_floating_point_operations, and folds the result intofloating_point_operations_so_far. One tiny (3-element) reduce + one host sync per step, at the existing end-of-step sync boundary.Wired into the step functions:
gpt_step.py— LLM, offline THD viaPackedSequenceSpecsvlm_step.py— legacy VLM in-batch packingqwen_vl/qwen3_vl_step.py— Qwen3-VL family in-batch packingqwen_omni/qwen3_omni_step.py— Qwen3-Omni (vision-patch tracking + BSHD seqlen²)Throughput logging fix
training_lognow derives the per-intervalthroughput/tflops(WandB / TensorBoard / MLflow / Comet / console) from the cumulative FLOPs delta since the last log ÷ the interval elapsed time — numerator and denominator over the same window. Previously it used the log-boundary step's FLOPS ÷ the interval-average iteration time, which over- or under-reports under variable-length THD (the per-step FLOPS vary). The cumulativefloating_point_operations_so_faris now the single source of truth, so the logged throughput cannot disagree with it, andtraining_log's ownresolve_…call (a second per-log DP all-reduce) is removed.Why
For a packed batch of length
L_packholdingNreal sub-sequences of lengthss₁ … s_N:L_pack²Σᵢ sᵢ²Equality only when
N == 1(a single sub-seq fills the pack). For typical packed SFT or in-batch VLM packing,Σᵢ sᵢ² ≪ L_pack²because cross-sub-seq attention does not happen.This is a metric correction, not a throughput regression
Replacing the inflated BSHD
pack_length²with the accurate (smaller) THDΣᵢ sᵢ²lowers the reported TFLOP/s by ~7% on packed SFT — at unchanged wall-clock. Confirmed on cw-dfw (qwen3_8b_sft, DP=8):Same clock, lower number — the old number was an over-count. The reported-TFLOP/s drop is the fix working, not a slowdown. (An earlier iteration of this PR added precompute/defer/buffer machinery to "recover" the 7%; that was chasing a measurement artifact in the reported metric, so it has been removed in favor of the straightforward inline implementation above. CI perf goldens for the affected recipes need re-baselining to the corrected values — tracked separately.)
Verification (FLOPS-value correctness)
Two identical workspaces on cw-dfw — only the source files differ. Same model, dataset, seed, parallelism. Identical loss and iteration times across paired runs confirm only the reported FLOPS changes.
qwen3_8b_sft_configqwen3_8b_sft_configqwen35_vl_9b_sft_configThe seq=2048 vs 4096 pair on the same LLM recipe is the cleanest demonstration: the fix reports near-identical TFLOP/s/GPU (155.8 vs 156.7) because actual attention work is set by per-sample lengths, not pack length. The baseline doubles (162.6 → 339.9) because its
pack_length²formula scales with pack size while the real work doesn't — a clear, growing over-count.Tests
tests/unit_tests/training/utils/test_flop_utils.py:mbs * seq_len²Σᵢ sᵢ²; padded cu_seqlens +cu_seqlens_argmintruncation;cu_seqlens_unpaddedprecedence; additive accumulation across micro-batches; zero-length sub-seq contributes 0; 32×256 pack ≈ 32× smaller than BSHD; tokens-None no-op; inline-path computes without bufferingresolve_global_flops_seqlen_stats: extrapolation fallback, VPP correction, empty →None, scalar-tensor coercion, dp_group-ignored-when-dist-not-initializedtests/unit_tests/training/test_gpt_step.py:get_batch10-tuple (packed-seq metadata) unchanged.tests/unit_tests/training/utils/test_train_utils.py:training_logthroughput = interval-FLOPs-delta ÷ interval-time ÷ world-size, and the_flops_at_last_loganchor advances.Out of scope
llava_step.pyandaudio_lm_step.pystill fall back tocfg.model.seq_length(helper not wired). Easy follow-up.hybrid_flops) still passes scalarseq_len, dropping quadratic accuracy for hybrid attention — independent of this PR.Σᵢ sᵢ²comes from the full (un-sharded)cu_seqlens— acp_sizescale mismatch. Not introduced here and not exercised by the CP=1 verification runs. Fixed in follow-up [training, perf] fix: CP-correct token count in FLOPS (cp_size) #4161.Labels
bug·area:perf·area:training·needs-review