
feat: Qwen3.5 VLM TP+PP support with per-microbatch grad reduce-scatter knob#1859

Open

akoumpa wants to merge 13 commits into main from akoumparouli/fix_qwen3_5_extract_model_layers

Conversation

@akoumpa
Contributor

@akoumpa commented Apr 15, 2026

Summary

@HuiyingLi
Enables tensor + pipeline parallelism for Qwen3.5-27B VLM end-to-end, and adds a new PipelineConfig.reduce_grad_per_microbatch knob that keeps FSDP gradients sharded across microbatches (saves ~27 GB per rank for a 13B-trainable-param stage).

Changes

Qwen3.5 VLM TP plan (optimized_tp_plans.py, parallelizer.py)

  • Register Qwen3_5ForConditionalGeneration in PARALLELIZE_FUNCTIONS; the plan delegates to get_hf_tp_shard_plan, which reads transformers' base_model_tp_plan from Qwen3_5TextConfig and prefixes it with model.language_model..
  • Add Qwen3.5 VLM to get_hf_tp_shard_plan's dispatch so inner-model nesting resolves correctly.
  • Translate transformers' new replicated_with_grad_allreduce style as a no-op under FSDP+TP (norm weights are naturally replicated on the TP mesh; FSDP handles grad sync).
  • Note: linear_attn (GatedDeltaNet) layers remain un-TP-sharded — transformers itself doesn't provide a plan for them since the stock chunk_gated_delta_rule / causal_conv1d_fn kernels aren't TP-aware.
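A minimal sketch of the plan translation described above, assuming the shapes stated in this PR (prefixing transformers' `base_model_tp_plan` with `model.language_model.` and skipping `replicated_with_grad_allreduce` entries). The helper name `build_qwen3_5_vlm_tp_plan` is illustrative; only `PARALLELIZE_FUNCTIONS` and `get_hf_tp_shard_plan` are real names from this PR.

```python
# Illustrative sketch: prefix transformers' base_model_tp_plan so it targets
# the text model nested inside the VLM. Helper name is hypothetical.

def build_qwen3_5_vlm_tp_plan(text_config):
    """Prefix the text model's base_model_tp_plan for the nested VLM layout."""
    base_plan = getattr(text_config, "base_model_tp_plan", None) or {}
    plan = {}
    for pattern, style in base_plan.items():
        # transformers' new replicated_with_grad_allreduce style becomes a
        # no-op under FSDP+TP: norm weights are already replicated on the TP
        # mesh and FSDP handles the grad sync, so those entries are skipped.
        if style == "replicated_with_grad_allreduce":
            continue
        plan[f"model.language_model.{pattern}"] = style
    return plan
```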

reduce_grad_per_microbatch knob (config.py, autopipeline.py, functional.py, fsdp_mixin.py, kd.py)

  • Default False preserves current behavior (FSDP no_sync across microbatches, reduce-scatter once at the end).
  • When True, every microbatch backward calls set_requires_gradient_sync(True) so FSDP reduce-scatters per microbatch. Grads stay sharded; the full-stage no_sync accumulator (stage_trainable_params × 2 bytes) is eliminated.
  • Trade-off: N reduce-scatters per step instead of 1; memory savings ~27 GB per rank for a 13B-param stage in bf16.
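The knob's effect on the microbatch loop can be sketched as follows. This is a simplified stand-in for the patched `backward_maybe_with_nosync`, using FSDP2's real `set_requires_gradient_sync` API; the loop structure and function name here are illustrative.

```python
# Illustrative sketch of the behavior the knob controls; the real logic lives
# in the patched backward_maybe_with_nosync path.

def run_microbatches(model, microbatches, loss_fn, reduce_grad_per_microbatch):
    n = len(microbatches)
    for i, batch in enumerate(microbatches):
        is_last = i == n - 1
        if reduce_grad_per_microbatch:
            # Reduce-scatter every microbatch: grads stay sharded across DP,
            # so the full-size (stage_trainable_params x 2 bytes) accumulator
            # never materializes.
            model.set_requires_gradient_sync(True)
        else:
            # Default: no_sync for all but the last microbatch, then a single
            # reduce-scatter, at the cost of holding full-size grads.
            model.set_requires_gradient_sync(is_last)
        loss_fn(model(batch)).backward()
```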

Recipe: examples/vlm_finetune/qwen3_5/qwen3_5_27b_tp4pp4.yaml — a 2-node (16 GPUs) tp=4, pp=4, dp=1 config with the new knob enabled.
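For orientation, the parallelism shape of that recipe looks roughly like the fragment below. This is an illustrative fragment only — the key names are not taken from the actual recipe schema; only the tp=4/pp=4/dp=1 layout and the `reduce_grad_per_microbatch` field come from this PR.

```yaml
# Illustrative fragment — key layout is hypothetical, values are from the PR.
pipeline:
  pp_size: 4
  reduce_grad_per_microbatch: true   # keep FSDP grads sharded per microbatch
distributed:
  tp_size: 4
  dp_size: 1                         # 2 nodes x 8 GPUs = tp4 x pp4 x dp1
```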

Validation (8 GPUs, pp=2, tp=1, dp=4, lbs=4)

| Config | Peak per rank | Outcome |
| --- | --- | --- |
| default (no knob) | 66 GB → OOM at step 1 bwd | ❌ broken |
| knob=True | 32–42 GB across steps 0–2 | ✅ fits |

Measured directly, the full-grad accumulator dropped from 26.9 GB (425 full-size grad tensors) to 6.7 GB (0 full-size grad tensors) after the microbatch-0 backward.

100-step convergence run (wandb)

[image: wandb loss curve for the 100-step run]

Test plan

  • Local pp=2/dp=4 validated with and without knob (OOM vs fit)
  • Local tp=2/pp=2/dp=2 with knob + new Qwen3.5 TP plan (fits, 3 steps clean)
  • 2-node tp=4/pp=4/dp=1 with knob (100 steps, wandb loss curve)
  • pp=1 baseline for convergence comparison
  • Gemma4 VLM regression check (uses same patched_backward_maybe_with_nosync path)
  • MoE (Qwen3moeVL) PP regression check (primary user of patched_backward_maybe_with_nosync)

🤖 Generated with Claude Code

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa added the r0.4.0 label (Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge.) Apr 15, 2026
@akoumpa
Contributor Author

akoumpa commented Apr 15, 2026

/ok to test 010ad75

HuiyingLi and others added 3 commits April 16, 2026 00:38
…1813)

* fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params

PR #1711 changed _should_load_before_shard to return False for multi-GPU
DP, so models stay on meta device through FSDP wrapping. This broke the
__dict__ trick in PR #1710's patch_hf_model.

Move the gate computation into _Fp32ParamHolder.forward() so FSDP's
unshard/reshard lifecycle fires naturally. Override CPAwareGatedDeltaNet
forward for both CP and non-CP paths to route through the holder.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* chore: remove test yaml not intended for PR

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: add sentinel to prevent __getattr__ re-wrapping

Address Claude review: guard against re-wrapping __getattr__ on
repeated patch_hf_model calls by checking a class-level sentinel
attribute.
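A minimal sketch of the sentinel guard described here: patching a class's `__getattr__` must be idempotent, or repeated `patch_hf_model` calls stack wrappers, each new `__getattr__` closing over the previous one. The helper and sentinel names below are illustrative.

```python
# Illustrative sketch: idempotent __getattr__ patching via a class-level
# sentinel attribute. Names are hypothetical stand-ins for the real patch.

_SENTINEL = "_automodel_getattr_patched"

def patch_getattr_once(cls, make_wrapper):
    if getattr(cls, _SENTINEL, False):
        return  # already wrapped; skip re-wrapping on repeated calls
    original = getattr(cls, "__getattr__", None)
    cls.__getattr__ = make_wrapper(original)
    setattr(cls, _SENTINEL, True)
```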

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: add upstream version comment to _forward_no_cp

Address Claude review: note the transformers version the forward was
copied from to ease future upstream diffing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: update MoE test expectations for _forward_no_cp path

TestForwardFastPath tests expected super().forward() to be called,
but the non-CP path now uses _forward_no_cp(). Update mocks to match.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* test: add coverage for _Fp32ParamHolder, _compute_gate, and sentinel guard

Add unit tests for:
- _Fp32ParamHolder.forward gate computation and dtype preservation
- _compute_gate routing through holder vs inline fallback
- patch_hf_model sentinel preventing __getattr__ re-wrapping

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* test: add coverage for _forward_no_cp and forward() dispatch paths

Add 14 new tests covering the critical _forward_no_cp method (lines
91-193) and forward() dispatch logic (lines 207-213) to satisfy
codecov/patch requirements for PR #1813:

- _forward_no_cp basic forward, cache_params=None, causal_conv1d_fn
  fallback, causal_conv1d_fn set, attention_mask, GQA repeat-interleave,
  _compute_gate delegation, and output dtype
- forward() dispatch when _cp_mesh is None or size <= 1, parameter
  pass-through, and extra CP kwargs
- _make_fp32_getattr fallback to AttributeError and real attr resolution

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…mparouli/fix_qwen3_5_extract_model_layers

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@akoumpa
Contributor Author

akoumpa commented Apr 17, 2026

/ok to test 23803c5

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@akoumpa
Contributor Author

akoumpa commented Apr 19, 2026

/ok to test 317bf7f

- parallelizer.py: handle nn.ModuleDict in _fsdp_by_dtype, safe attr walk
  in _extract_model_layers, and use string key for Qwen3_5ForConditionalGeneration
- hf_utils.py: route pipeline forward through get_text_module so nested VLM
  text modules (model.language_model.{embed_tokens,layers,norm}) work
- finetune.py: update_seq_len per-batch to precompute pipeline stage shapes
  analytically (needed for GatedDeltaNet and VLM variable-length batches)
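The "safe attr walk" idea for nested VLM modules can be sketched as a dotted-path resolver that returns None instead of raising when an intermediate attribute is absent, so paths like model.language_model.layers resolve on VLMs while plain-text models are handled gracefully. The helper name is illustrative.

```python
# Illustrative sketch of a safe attribute walk over nested module paths.

def safe_attr_walk(root, dotted_path):
    obj = root
    for name in dotted_path.split("."):
        obj = getattr(obj, name, None)
        if obj is None:
            return None  # path does not exist on this model variant
    return obj
```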

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…3_5_extract_model_layers

# Conflicts:
#	nemo_automodel/components/distributed/parallelizer.py
#	nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py
#	tests/unit_tests/models/qwen3_5/test_cp_linear_attn_patch.py
#	tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Qwen3.5 VLM TP support:
- Register Qwen3_5ForConditionalGeneration in PARALLELIZE_FUNCTIONS; plan
  delegates to get_hf_tp_shard_plan which reads transformers'
  base_model_tp_plan with prefix model.language_model. GatedDeltaNet layers
  stay un-sharded (no stock TP plan exists for linear_attn).
- Extend get_hf_tp_shard_plan dispatch to handle Qwen3.5 VLM nesting.
- Translate transformers' "replicated_with_grad_allreduce" style as a no-op
  (skip in plan) — under FSDP+TP the TP-mesh replication already behaves
  correctly for norm weights.

Per-microbatch grad reduce-scatter (PipelineConfig.reduce_grad_per_microbatch):
- When True, FSDP reduce-scatters grads every microbatch instead of
  accumulating full-size grads under no_sync until the last one. Keeps
  grads sharded across DP, saving ~stage_trainable*2 bytes per rank
  (~27 GB for a 13B-trainable-param stage in bf16). Default False.
- Plumbed through PipelineConfig -> AutoPipeline -> pipeline_model; patches
  stages via types.MethodType and stores the flag on each stage; the
  patched backward_maybe_with_nosync branches on stage._reduce_grad_per_microbatch.
- kd.py teacher pipeline config propagates the field.
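The stage-patching mechanics above can be sketched as follows: stamp the flag onto each stage and bind the patched function as an instance method with `types.MethodType`, so the patched backward can branch on `stage._reduce_grad_per_microbatch`. The stage class and function below are simplified stand-ins.

```python
import types

# Illustrative sketch of per-stage patching via types.MethodType.

def patch_stage(stage, patched_backward, reduce_grad_per_microbatch):
    stage._reduce_grad_per_microbatch = reduce_grad_per_microbatch
    # Bind the free function as a method on this specific stage instance,
    # so `self` inside patched_backward is the stage.
    stage.backward_maybe_with_nosync = types.MethodType(patched_backward, stage)
```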

Validated locally on 8 GPUs: pp=2, dp=4, lbs=4 peak drops from 66 GB (OOM
at step 1 on 80 GB GPUs) to 32-41 GB across 3 steps with knob=True.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
2-node (16 GPUs) tp=4, pp=4, dp=1 config. Uses the new
reduce_grad_per_microbatch knob to keep grads sharded across microbatches.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…_model_layers' into akoumparouli/fix_qwen3_5_extract_model_layers

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…_model_layers' into akoumparouli/fix_qwen3_5_extract_model_layers

# Conflicts:
#	nemo_automodel/components/distributed/parallelizer.py

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi force-pushed the akoumparouli/fix_qwen3_5_extract_model_layers branch from 6e4a5ff to 79bda38 on April 20, 2026 03:50
@HuiyingLi marked this pull request as ready for review April 20, 2026 03:55
@HuiyingLi changed the title from "fix: add qwen3_5 to _extract_model_layers" to "feat: Qwen3.5 VLM TP+PP support with per-microbatch grad reduce-scatter knob" Apr 20, 2026
- Use Qwen/Qwen3.5-27B instead of a local checkpoint path
- Add commented-out wandb section so users know how to enable it

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>