Fix region compilation fsdpv2 #4022

Merged
SunMarc merged 5 commits into main from fix-region-compilation-fsdpv2
Apr 29, 2026

@SunMarc SunMarc commented Apr 28, 2026

What does this PR do?

This PR fixes regional compilation with FSDPv2. compile_regions rebuilds each repeated block as torch.compile(submodule), which returns an OptimizedModule wrapper. That wrapper's __call__ bypasses nn.Module._call_impl, so the forward and forward-pre hooks that fully_shard adds afterwards never fire. Without those hooks there is no per-layer all-gather or reshard, and therefore no communication/compute overlap. The fix is basically the same thing we did for DeepSpeed.
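The failure mode described above can be illustrated with a toy, torch-free sketch. ToyBlock and ToyWrapper are hypothetical stand-ins invented for this illustration; they are not PyTorch classes and do not reproduce the real OptimizedModule or fully_shard code. The sketch only shows why a hook registered on an inner module is skipped when a wrapper calls the inner forward directly.

```python
class ToyBlock:
    """Hypothetical stand-in for an nn.Module: __call__ runs pre-hooks, then forward."""

    def __init__(self):
        self.pre_hooks = []

    def forward(self, x):
        return x + 1

    def __call__(self, x):
        # Analogous to the hook dispatch that nn.Module._call_impl performs.
        for hook in self.pre_hooks:
            hook(self, x)
        return self.forward(x)


class ToyWrapper:
    """Hypothetical stand-in for a compile wrapper that calls the inner forward directly."""

    def __init__(self, inner):
        self.inner = inner

    def __call__(self, x):
        # Skips inner.__call__, so hooks registered on the inner block never run.
        return self.inner.forward(x)


def count_hook_calls(call_through_wrapper):
    """Return (output, number of pre-hook invocations) for one call."""
    block = ToyBlock()
    calls = []
    # Stand-in for the per-layer all-gather trigger that a sharding hook would provide.
    block.pre_hooks.append(lambda module, x: calls.append(x))
    target = ToyWrapper(block) if call_through_wrapper else block
    output = target(1)
    return output, len(calls)


wrapped = count_hook_calls(call_through_wrapper=True)
direct = count_hook_calls(call_through_wrapper=False)
print("through wrapper:", wrapped)
print("direct call:", direct)
```

Both calls compute the same output, but only the direct call runs the hook. In the real setting the skipped hook is what would schedule the per-layer all-gather and reshard.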

Compiling the whole model doesn't lead to a speed-up, so maybe regional compilation should be the default path?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@AmineDiro
Member

This fixes the slow compile path with trl.SFTTrainer 🎉

Setup (same as before): Qwen3-30B-A3B (MoE, 128 experts, 48 layers) · 2×8 H100 SXM 80GB · FSDP2 DP=16 · seq_len=16384 · SFTTrainer w/ grad ckpt, bf16, packing=True.

Required yaml addition (the only diff vs my slow-path repro):

dynamo_config:
  backend: inductor
  use_fullgraph: true
  use_regional_compilation: true

Results:

| Setup | MFU | ms/step | vs eager |
| --- | --- | --- | --- |
| raw fully_shard + eager | 25.0% | 3,888 | |
| raw fully_shard + per-layer compile | 32.1% | 3,031 | 1.28× faster |
| accelerate fsdp2_prepare_model + eager | 23.4% | 4,160 | |
| accelerate fsdp2_prepare_model + per-layer compile (BEFORE this PR) | 9.8% | 9,900 | 2.4× slower |
| accelerate fsdp2_prepare_model + per-layer compile + use_regional_compilation=true (THIS PR) | ~32% | ~3,000 | 1.28× faster (matches raw fast path) |

mfu_window samples on this PR over 3 logging steps: 32.55 / 31.27 / 31.81 %. Regression fully closed.
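For anyone checking the numbers, here is a small sketch of the arithmetic behind the "vs eager" column and the ~32% figure. It assumes each ratio compares ms/step against the eager row of the same setup family; the dictionary keys are labels made up for this snippet. The approximate ~3,000 ms/step row is left out because it is not an exact measurement.

```python
# ms/step values copied from the exact rows of the results table.
ms_per_step = {
    "raw_eager": 3888,
    "raw_compile": 3031,
    "accelerate_eager": 4160,
    "accelerate_compile_before": 9900,
}

# Speed-up is baseline time divided by new time; a slowdown is the inverse ratio.
raw_speedup = ms_per_step["raw_eager"] / ms_per_step["raw_compile"]
before_slowdown = (
    ms_per_step["accelerate_compile_before"] / ms_per_step["accelerate_eager"]
)

# mfu_window samples reported for this PR, in percent.
mfu_samples = [32.55, 31.27, 31.81]
mean_mfu = sum(mfu_samples) / len(mfu_samples)

print(f"raw per-layer compile vs raw eager: {raw_speedup:.2f}x faster")
print(f"accelerate per-layer compile before this PR: {before_slowdown:.2f}x slower")
print(f"mean MFU over the three samples: {mean_mfu:.2f}%")
```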

Repro script (gist): https://gist.github.com/AmineDiro/9fd331214626b60e4d421264637b3828 — same as my original repro, only the yaml gains the dynamo_config block.

@SunMarc SunMarc merged commit 917e2a9 into main Apr 29, 2026
27 of 29 checks passed
@SunMarc SunMarc deleted the fix-region-compilation-fsdpv2 branch April 29, 2026 13:52