Commit 82761e7

cuichenx and claude committed
[training, perf] fix: only use cu_seqlens THD FLOPS path for CP==1
L0_Launch_training's test_sft_example_runs_with_cp_and_packing (CP=2) hangs deterministically on this PR (4/4 runs; main at the merged commit is green). The test exercises the LLM gpt_step under context parallelism + packing, where this PR newly runs the cu_seqlens-driven THD Σᵢ sᵢ² accounting in the per-microbatch forward path. That path is only wired/validated for CP==1; under CP the batch and its cu_seqlens are CP-partitioned per rank, so the per-rank computation is not yet correct (the follow-up tracked in #4161) and was destabilizing the run.

Forward cu_seqlens to accumulate_flops_metadata only when CP==1; under CP>1 fall back to the BSHD term, the exact behavior this test passed on before the THD change. CP==1 (the verified configuration) is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
1 parent d443d38 commit 82761e7

1 file changed

Lines changed: 12 additions & 4 deletions

src/megatron/bridge/training/gpt_step.py
@@ -321,13 +321,21 @@ def _forward_step_common(
     # Accumulate FLOPS metadata across micro-batches. The THD attention term Σᵢ sᵢ² is
     # derived inline from cu_seqlens (kept on-device, sync-free); see
     # accumulate_flops_metadata. Falls back to BSHD when cu_seqlens is absent.
+    #
+    # The cu_seqlens-driven THD path is only wired/validated for CP == 1 in this PR.
+    # Under context parallelism the batch (and its cu_seqlens) is CP-partitioned per
+    # rank, so the per-rank Σᵢ sᵢ² accounting here is not yet correct — that is the
+    # follow-up tracked in #4161. Until then, forward cu_seqlens only for CP == 1 so
+    # CP > 1 stays on the BSHD term (the behavior this test passed on before the THD
+    # change), instead of running the not-yet-CP-safe cu_seqlens path.
+    cp_use_thd = pg_collection.cp.size() == 1
     accumulate_flops_metadata(
         state,
         tokens,
-        cu_seqlens=cu_seqlens,
-        cu_seqlens_argmin=cu_seqlens_argmin,
-        cu_seqlens_unpadded=cu_seqlens_unpadded,
-        cu_seqlens_unpadded_argmin=cu_seqlens_unpadded_argmin,
+        cu_seqlens=cu_seqlens if cp_use_thd else None,
+        cu_seqlens_argmin=cu_seqlens_argmin if cp_use_thd else None,
+        cu_seqlens_unpadded=cu_seqlens_unpadded if cp_use_thd else None,
+        cu_seqlens_unpadded_argmin=cu_seqlens_unpadded_argmin if cp_use_thd else None,
     )
 
     forward_args = {
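
To make the gating in this hunk concrete, here is a minimal pure-Python sketch of the two sequence-length terms and the CP == 1 switch. It is not the body of accumulate_flops_metadata, which this commit does not show. The helper names (thd_attention_term, bshd_attention_term, gate_for_cp, attention_term) are hypothetical, and the BSHD term is assumed to be batch_size * seq_len² over the padded length. The real code keeps cu_seqlens on-device as a tensor; plain lists are used here only for illustration.

```python
from itertools import accumulate
from typing import Optional, Sequence


def thd_attention_term(cu_seqlens: Sequence[int]) -> int:
    """Sum of squared per-sequence lengths (the Σᵢ sᵢ² term) from cumulative boundaries."""
    return sum((end - start) ** 2 for start, end in zip(cu_seqlens, cu_seqlens[1:]))


def bshd_attention_term(batch_size: int, seq_len: int) -> int:
    """Dense fallback (assumed form): each row attends over the full padded length."""
    return batch_size * seq_len**2


def gate_for_cp(cu_seqlens: Optional[Sequence[int]], cp_size: int) -> Optional[Sequence[int]]:
    """Mirror of the commit's gating: forward cu_seqlens only when CP == 1."""
    return cu_seqlens if cp_size == 1 else None


def attention_term(
    batch_size: int,
    seq_len: int,
    cu_seqlens: Optional[Sequence[int]],
    cp_size: int,
) -> int:
    """Pick the THD term when cu_seqlens survives the gate, else the BSHD term."""
    gated = gate_for_cp(cu_seqlens, cp_size)
    if gated is None:
        return bshd_attention_term(batch_size, seq_len)
    return thd_attention_term(gated)


# Three packed sequences of lengths 3, 5 and 8 in one 16-token row.
lengths = [3, 5, 8]
cu_seqlens = [0, *accumulate(lengths)]  # [0, 3, 8, 16]

print(attention_term(1, 16, cu_seqlens, cp_size=1))  # 9 + 25 + 64 = 98
print(attention_term(1, 16, cu_seqlens, cp_size=2))  # 16**2 = 256
```

With CP == 1 the packed row is charged 98 units instead of 256, which is the point of the THD accounting. With CP > 1 the gate drops cu_seqlens and the count reverts to the dense term, matching the pre-THD behavior the commit message describes.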
