Plumb model_cfg.fp32_output through Float16Module #52

Merged
Randomizez merged 1 commit into stepfun-ai:dev from yuruofeifei:fix/fp32-output-plumbing
May 5, 2026
Conversation

@yuruofeifei

Contributor

Summary

  • Float16Module accepts an fp32_output kwarg (default True) and honors
    it in forward, but neither trainer forwards it: lm_trainer.py:352 and
    packed_model.py:107 instantiate Float16Module(model_module, params_dtype)
    with two args only, so the kwarg has silently had no effect.
  • This patch forwards getattr(model_config, "fp32_output", True) from both
    trainers and declares fp32_output: bool = True on MegatronPPModelConfig
    (sibling to fp32_residual_connection) so the flag has a discoverable home.
  • Default behavior is preserved (True); recipes that opt out now correctly
    skip the fp32 logit upcast in Float16Module.forward, bounding transient
    memory on long-seq + large-vocab bf16 configs.
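The plumbing described above can be sketched as follows. This is a minimal illustration, not the code from this PR: `ModelConfigSketch`, `Float16ModuleSketch`, and `wrap_model` are hypothetical stand-ins for MegatronPPModelConfig, Float16Module, and the trainer call sites, and dtypes are represented as plain strings so the sketch runs without torch.

```python
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for MegatronPPModelConfig."""
    fp32_residual_connection: bool = False
    fp32_output: bool = True  # new field; the default keeps the fp32 upcast


class Float16ModuleSketch:
    """Hypothetical stand-in for Float16Module; it only records the flag."""

    def __init__(self, module, params_dtype, fp32_output=True):
        self.module = module
        self.params_dtype = params_dtype
        self.fp32_output = fp32_output

    def output_dtype(self):
        # Stand-in for the cast in forward: report the dtype the logits
        # would have, instead of casting a real tensor.
        return "fp32" if self.fp32_output else self.params_dtype


def wrap_model(model_module, model_config, params_dtype="bf16"):
    # getattr with a default of True means configs that lack the field
    # behave exactly as they did before the flag was forwarded.
    return Float16ModuleSketch(
        model_module,
        params_dtype,
        fp32_output=getattr(model_config, "fp32_output", True),
    )
```

With this shape, a config that sets `fp32_output=False` yields a wrapper reporting `bf16`, while a default config, or any object without the attribute, still reports `fp32`.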

Repro

  • Recipe: TP=4, vocab≈248k, seq=128k, bf16, model_cfg.fp32_output = False.
  • Before this patch: setting the flag is a no-op; logits are upcast to fp32,
    adding ~14 GB of transient activation memory and OOMing on borderline configs.
  • After this patch: same flag now correctly suppresses the upcast.
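As a rough guide to why the upcast matters on long-seq, large-vocab configs, the size of an fp32 copy of the logits scales with seq × vocab / TP. The helper below is a back-of-the-envelope sketch under assumptions not stated in this PR: logits of shape [micro_batch, seq, vocab / TP] per rank, and no further sharding. It does not attempt to reproduce the ~14 GB figure above, which depends on details of the recipe; the function names are hypothetical.

```python
def logits_bytes(seq_len, vocab_size, tp_size, bytes_per_elem, micro_batch=1):
    """Bytes held by one rank's logits tensor (assumed layout, see lead-in)."""
    # Ceil-divide so an uneven vocab split is not under-counted.
    vocab_per_rank = -(-vocab_size // tp_size)
    return micro_batch * seq_len * vocab_per_rank * bytes_per_elem


def fp32_copy_bytes(seq_len, vocab_size, tp_size, micro_batch=1):
    """Size of the extra fp32 tensor that an upcast of the logits allocates."""
    return logits_bytes(seq_len, vocab_size, tp_size, 4, micro_batch)


def bf16_logits_bytes(seq_len, vocab_size, tp_size, micro_batch=1):
    """Size of the logits when left in bf16 (2 bytes per element)."""
    return logits_bytes(seq_len, vocab_size, tp_size, 2, micro_batch)
```

Under these assumptions the fp32 copy is twice the size of the bf16 logits it is created from, and doubling either the sequence length or the per-rank vocab doubles both.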

Verification

  • Manually verified on the repro recipe at TP=2: previously OOM'd, now trains
    successfully with model_cfg.fp32_output = False. The fp32 logit upcast is
    correctly skipped and the ~14 GB transient activation is reclaimed.
  • Default path (fp32_output=True) is unchanged: getattr(..., True) keeps
    prior behavior for all existing recipes.

No unit test added — the fix is config plumbing; the kwarg's runtime
behavior in Float16Module.forward is exercised by every existing fp16/bf16
training run.

🤖 Generated with Claude Code

The fp32_output kwarg on Float16Module was never forwarded by either
trainer (lm_trainer.py and packed_model.py), so its consumer in
Float16Module.forward fell back to its default of True regardless of
the model config. Setting model_cfg.fp32_output=False was silently a
no-op, and the final logits always upcast to fp32.

This patch:
- Forwards getattr(model_config, "fp32_output", True) in both trainers.
- Declares fp32_output: bool = True on MegatronPPModelConfig (sibling
  to fp32_residual_connection) so it is discoverable.

Default behavior is preserved -- existing recipes are unaffected.
Recipes that opt out (model_cfg.fp32_output = False) now correctly
skip the fp32 logit upcast, which bounds transient memory on long-seq
+ large-vocab bf16 configs (e.g. vocab=248k, seq=128k saves ~14 GB).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
yuruofeifei requested a review from a team May 5, 2026 04:37
Randomizez merged commit 850b428 into stepfun-ai:dev May 5, 2026
7 checks passed