feat: Vision Data Parallel for VLM training with CP#1

Open
aoshen524 wants to merge 51 commits into `main` from `feat/vision-dp`

Conversation

@aoshen524

Summary

  • Vision DP distributes whole images across CP ranks for independent ViT computation, replacing redundant ViT execution on every rank
  • Single post-ViT all-gather collects embeddings back — zero ViT-internal communication
  • Supports Qwen2.5-VL and Qwen3-VL (including deepstack)
  • Gated behind vision_dp: bool = False in FSDPArgs — opt-in, default behavior unchanged
  • ViT gradient sync via sync_vision_grads_across_cp also gated behind the flag
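The distribution-then-gather scheme above can be simulated with a toy sketch (plain Python, no distributed code; `fake_vit` is a hypothetical stand-in for the real ViT):

```python
def fake_vit(images):
    """Stand-in for the ViT: one embedding per image (toy, not model code)."""
    return [f"emb({i})" for i in images]

images = list(range(6))
cp_size = 3
chunk = len(images) // cp_size
# Image-level contiguous slices, one per CP rank; each rank runs the ViT
# only on its own slice.
per_rank = [fake_vit(images[r * chunk:(r + 1) * chunk]) for r in range(cp_size)]
# Post-ViT all-gather: concatenating rank outputs in rank order restores
# the original image order, so the ViT needs no internal communication.
gathered = [emb for part in per_rank for emb in part]
assert gathered == fake_vit(images)
```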

Key changes

| File | Change |
| --- | --- |
| `miles/utils/vision_dp.py` | Core utilities: assignment, slicing, all-gather with gradient support |
| `miles/backends/fsdp_utils/actor.py` | Gate Vision DP patch + grad sync behind `self.args.vision_dp` |
| `miles/backends/fsdp_utils/arguments.py` | Add `vision_dp: bool = False` to `FSDPArgs` |

Usage

```python
# In FSDPArgs
vision_dp = True
context_parallel_size = 2
```

Precision Alignment (verl reference experiment)

Validated in verl-project/verl#5230 under controlled conditions (same algorithm shared across frameworks):

| Scope | Params | max_diff | mean_diff | cosine_sim |
| --- | --- | --- | --- | --- |
| vision | 390 | 4.70e-05 | 2.93e-08 | 0.9991 |
| language | 338 | 9.50e-08 | 1.15e-10 | 1.0020 |
| other | 1 | 9.13e-08 | 2.25e-13 | 1.0001 |
  • Vision DP is numerically lossless: all differences within bf16 precision (~1e-05 max)
  • Language gradients are bitwise identical at pre-clip phase
  • Root cause: all_reduce(SUM) changes FP accumulation order in vision backward
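The accumulation-order effect is easy to reproduce with plain floats (a minimal illustration of non-associative FP addition, not the actual kernel behavior):

```python
# Floating-point addition is not associative, so reordering the terms of a
# reduction (as all_reduce(SUM) may do) can change the low-order bits.
a, b, c = 1e16, 1.0, -1e16

left_to_right = (a + b) + c  # 1.0 is absorbed into 1e16, then cancelled -> 0.0
reordered = (a + c) + b      # the large terms cancel first, 1.0 survives -> 1.0

print(left_to_right, reordered)
```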

Test plan

  • Unit tests for Vision DP utilities
  • Multi-GPU distributed test with CP enabled

🤖 Generated with Claude Code

aoshen524 and others added 30 commits February 16, 2026 01:14
When using Context Parallelism (cp_size > 1), Ring Flash Attention splits
text attention across CP ranks, but the VisionTransformer (ViT) still
processes ALL images on every rank, making ViT memory the bottleneck for
multi-turn VLM training with many screenshots.

Vision DP distributes whole images (not patches) across CP ranks:
- Before: Each of N CP ranks processes ALL images -> O(total_images)
- After: Each rank processes total_images/N images -> O(total_images/N)

Key design:
- Image-level contiguous distribution (no reordering after all-gather)
- Gradient scaling by cp_size to compensate for FSDP reduction
- Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE

Adapted from verl PR #5230.
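The per-rank load change can be sketched with a hypothetical helper (illustrative, not the PR's code):

```python
def images_per_rank(total_images: int, cp_size: int) -> list[int]:
    """Contiguous split of whole images across CP ranks.

    Before Vision DP every rank processed all `total_images`; with it,
    rank r processes roughly total_images / cp_size images.
    """
    base, rem = divmod(total_images, cp_size)
    # The first `rem` ranks take one extra image each.
    return [base + (1 if r < rem else 0) for r in range(cp_size)]

print(images_per_rank(10, 4))  # [3, 3, 2, 2]
```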

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Address review comments from Gemini:

1. (Critical) Add sync_vision_grads_across_cp() to all-reduce AVG the
   vision tower parameter gradients across CP ranks after backward.
   Without this, FSDP only reduces across dp_mesh, causing ViT weights
   to diverge when Vision DP produces different gradients per CP rank.

2. (Medium) Replace print() with logger.info() in apply_vision_dp_patch.

Gradient math: GatherVisionEmbeddings backward scales output grads by
cp_size, so ViT param grads = cp_size * partial_grad. After AVG across
CP: mean(cp_size * partial_k) = total_grad. Correct.
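The identity in this commit message checks out numerically (plain Python with hypothetical gradient values):

```python
cp_size = 4
# Partial ViT parameter gradients produced by each CP rank's backward.
partials = [0.1, 0.4, 0.2, 0.3]
total_grad = sum(partials)

# GatherVisionEmbeddings.backward scales each rank's grad by cp_size;
# all_reduce(AVG) across CP then divides the sum by cp_size, so the
# scaling cancels and the true total gradient is recovered.
scaled = [cp_size * g for g in partials]
avg_across_cp = sum(scaled) / cp_size

assert abs(avg_across_cp - total_grad) < 1e-12
```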

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
…sample fields (radixark#548)

Co-authored-by: miles-code-angel <miles.pr.bot@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
…issues

Address reviewer comments (same fixes as verl PR #5230):

1. **Gradient routing fix (critical)**: Replace `grad_scaler * dp_size` with
   `all_reduce(SUM)` in GatherVisionEmbeddings.backward() to aggregate
   partial sequence gradients before slicing. Fixes silent gradient loss.

2. **sync_vision_grads_across_cp: AVG→SUM**: With the activation gradient
   fix, each rank's ViT backward produces partial (not scaled) param
   gradients. SUM (not AVG) across CP now correctly recovers the total.

3. **Load-balanced assignment**: Replace count-based chunking with greedy
   contiguous bin-packing that balances total patch load across ranks.

4. **Remove unnecessary all_gather**: Pass pre-computed `all_counts` from
   caller instead of doing all_gather in forward.

5. **Idempotency guard**: Add `_vision_dp_patched` attribute check in
   apply_vision_dp_patch to prevent double-wrapping.

6. **Remove Qwen3-VL-MoE dead code**: Remove unreachable qwen3_vl_moe
   block from apply_vision_dp_patch.

7. **GPU→CPU sync optimization**: Move `grid_thw.cpu()` to dp_vision_forward
   entry point.

8. **Tensor slicing**: Replace Python loop in prepare_local_vision_inputs
   with contiguous tensor slice using cumsum.

9. **Test improvements**: Rename tests, add load balancing test, add
   gather_none_group test, use parametrize.
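Items 3 and 8 above can be sketched together in plain Python (a hypothetical re-implementation of the ideas, not the PR's tensor code; assumes at least one image per rank):

```python
from itertools import accumulate

def assign_contiguous(patch_counts, dp_size):
    """Greedy contiguous bin-packing sketch: keep image order, close a bin
    once it reaches the ideal per-rank patch share, but always leave at
    least one image for every remaining rank."""
    target = sum(patch_counts) / dp_size
    bins, start, load = [], 0, 0
    for i, count in enumerate(patch_counts):
        load += count
        ranks_left = dp_size - len(bins) - 1
        images_left = len(patch_counts) - (i + 1)
        if ranks_left > 0 and images_left >= ranks_left and (
            load >= target or images_left == ranks_left
        ):
            bins.append(range(start, i + 1))
            start, load = i + 1, 0
    bins.append(range(start, len(patch_counts)))
    return bins

def patch_slice(patch_counts, images):
    """Cumsum-based (start, end) patch offsets for one rank's contiguous
    image range, mirroring the single-tensor-slice change of item 8."""
    cum = [0, *accumulate(patch_counts)]
    return cum[images.start], cum[images.stop]

counts = [8, 2, 2, 4]                    # patches per image
print(assign_contiguous(counts, 2))      # [range(0, 1), range(1, 4)]
print(patch_slice(counts, range(1, 4)))  # (8, 16)
```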

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: gongyisheng <yishenggong9437@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
aoshen524 and others added 21 commits March 3, 2026 22:47
…d contiguous guard

- Trim verbose docstrings to concise one-liners
- Delete dead store ctx.hidden_size (written in forward, never read in backward)
- Simplify hidden_size detection: self.config.out_hidden_size
- Add requires_grad_() for empty rank to participate in backward all_reduce
- Add .contiguous() guard before all_reduce (NCCL requirement)
- Reuse get_image_patch_counts in spatial_merge_size==1 path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Detect Qwen3-VL via model attribute (hasattr deepstack_merger_list)
instead of return type, so empty ranks that skip original_forward
still create matching empty deepstack tensors and participate in
all-gather — preventing NCCL deadlock.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add `vision_dp: bool = False` to FSDPArgs and gate both
apply_vision_dp_patch() and sync_vision_grads_across_cp() behind it.
Vision DP is now opt-in rather than auto-enabled when CP > 1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…adixark#642)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
- Replace `expected_patches = end_patch - start_patch` (always-true by
  Python slicing) with independent cross-check via
  `get_image_patch_counts(local_grid_thw)` in prepare_local_vision_inputs()
- Rename tests to `test_<what>_<condition>_<expected>()` convention
- Add missing tests: embedding_counts empty, contiguous coverage,
  gather same-storage

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sync shared utility functions with verl's stricter error handling:

- get_image_patch_counts/get_image_embedding_counts: empty grid_thw
  raises ValueError instead of returning []
- assign_images_to_dp_ranks: validate dp_size > 0, empty patch_counts
  raises ValueError instead of returning empty lists
- prepare_local_vision_inputs: add dp_rank bounds check, use tensor-ops
  for offset computation (avoid Python-list round-trip), add int() cast
- GatherVisionEmbeddings.forward: dp_size<=1 raises RuntimeError,
  validate all_counts length, max_count==0 raises RuntimeError
- GatherVisionEmbeddings.backward: assert dp_size>1, add CUDA check
- dp_vision_forward: cp_size<=1 raises RuntimeError, use
  GatherVisionEmbeddings.apply() directly, add detailed assert messages
- Update tests to match: empty→raises, add dp_size/dp_rank validation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ixark#656)

Co-authored-by: maocheng23 <35615230+maocheng23@users.noreply.github.com>
…xark#643)

Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Key changes from AReaL PR #929:
- Private `_` prefix on internal functions (cleaner public API)
- Simplified error handling: return empty instead of raising for edge cases
- Extract `_unpack_deepstack()` helper (was inline in dp_vision_forward)
- `_patch_vision_class()` as standalone function with `_VISION_CLASSES` registry
- `importlib`-based patching replaces repeated try/except import blocks
- Simplified `spatial_merge_size` lookup: `getattr(self, "spatial_merge_size", 1)`
- Hidden size fallback: `out_hidden_size` or `hidden_size`
- `.contiguous()` defensive guard in GatherVisionEmbeddings forward
- `dp_size==1` short-circuit in GatherVisionEmbeddings (instead of raise)
- `cp_size<=1` falls through to original_forward (instead of raise)
- Remove cross-check assertion in _prepare_local_vision_inputs
- Remove CUDA device check in backward (handled by NCCL)
- Use `_gather_vision_embeddings` wrapper consistently (including deepstack)
- Pass CPU grid_thw to _prepare_local_vision_inputs, move back to GPU after

Miles-specific (kept from original):
- Closure-based CP group passing (no Ulysses SP APIs)
- `sync_vision_grads_across_cp()` for explicit ViT param grad sync
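The `importlib`-based registry patching with an idempotency guard can be sketched as follows (hypothetical names and a stand-in module, not the PR's `_VISION_CLASSES` contents):

```python
import importlib
import sys
import types

def patch_vision_classes(registry, make_forward):
    """Patch each registered vision class's forward exactly once."""
    for module_path, class_names in registry.items():
        try:
            module = importlib.import_module(module_path)
        except ImportError:
            continue  # model family not installed: skip, no per-model try/except
        for name in class_names:
            cls = getattr(module, name, None)
            if cls is None or getattr(cls, "_vision_dp_patched", False):
                continue  # idempotency guard against double-wrapping
            cls.forward = make_forward(cls.forward)
            cls._vision_dp_patched = True

# Demo with a fake module standing in for a transformers model file.
fake = types.ModuleType("fake_vlm")
class FakeViT:
    def forward(self):
        return "original"
fake.FakeViT = FakeViT
sys.modules["fake_vlm"] = fake

wrap = lambda orig: (lambda self: "dp:" + orig(self))
patch_vision_classes({"fake_vlm": ["FakeViT"]}, wrap)
patch_vision_classes({"fake_vlm": ["FakeViT"]}, wrap)  # second call is a no-op
print(FakeViT().forward())  # dp:original
```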

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Both Glm4vVisionModel and Glm4vMoeVisionModel (GLM-5 744B) share the
same forward signature as Qwen series (hidden_states, grid_thw -> Tensor),
so no changes needed to create_dp_vision_forward — just register them.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
