feat: Vision Data Parallel for VLM training with CP#1
When using Context Parallelism (cp_size > 1), Ring Flash Attention splits text attention across CP ranks, but the VisionTransformer (ViT) still processes ALL images on every rank, making ViT memory the bottleneck for multi-turn VLM training with many screenshots.

Vision DP distributes whole images (not patches) across CP ranks:

- Before: each of N CP ranks processes ALL images -> O(total_images)
- After: each rank processes total_images/N images -> O(total_images/N)

Key design:

- Image-level contiguous distribution (no reordering after all-gather)
- Gradient scaling by cp_size to compensate for FSDP reduction
- Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE

Adapted from verl PR #5230.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
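The image-level contiguous distribution described above can be sketched as follows. This is an illustrative reimplementation, not the PR's actual API; the function name and return type are hypothetical. The key property is that concatenating the per-rank blocks in rank order reproduces the original image order, so the all-gather needs no reordering step.

```python
def split_images_contiguous(num_images: int, cp_size: int) -> list[range]:
    """Assign whole images to CP ranks in contiguous blocks, so the
    all-gathered per-image embeddings come back in the original order."""
    base, rem = divmod(num_images, cp_size)
    assignments, start = [], 0
    for rank in range(cp_size):
        count = base + (1 if rank < rem else 0)  # spread the remainder over the first ranks
        assignments.append(range(start, start + count))
        start += count
    return assignments

# 10 images over 4 CP ranks: contiguous blocks of sizes 3, 3, 2, 2.
print(split_images_contiguous(10, 4))
```

Each rank then runs the ViT only on its own block, reducing per-rank vision memory from O(total_images) to roughly O(total_images / cp_size).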
Address review comments from Gemini:

1. (Critical) Add sync_vision_grads_across_cp() to all-reduce (AVG) the vision tower parameter gradients across CP ranks after backward. Without this, FSDP only reduces across dp_mesh, causing ViT weights to diverge when Vision DP produces different gradients per CP rank.
2. (Medium) Replace print() with logger.info() in apply_vision_dp_patch.

Gradient math: GatherVisionEmbeddings backward scales output grads by cp_size, so ViT param grads = cp_size * partial_grad. After AVG across CP: mean_k(cp_size * partial_k) = sum_k(partial_k) = total_grad. Correct.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
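The gradient math in this commit can be checked with toy numbers. The values below are illustrative, not taken from any real run: each rank computes a partial parameter gradient from its image subset, the backward scales it by cp_size, and an AVG all-reduce then recovers the total gradient.

```python
# Toy check: AVG over ranks of (cp_size * partial_k) equals sum of partial_k.
cp_size = 4
partials = [0.5, 1.0, -0.25, 2.0]         # hypothetical per-rank partial param grads
scaled = [cp_size * g for g in partials]   # each rank holds cp_size * partial after backward
avg = sum(scaled) / cp_size                # what an AVG all-reduce across CP produces
total = sum(partials)                      # the gradient over ALL images
print(avg, total)                          # both 3.25: mean_k(cp * g_k) == sum_k(g_k)
```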
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
…sample fields (radixark#548)

Co-authored-by: miles-code-angel <miles.pr.bot@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
…issues

Address reviewer comments (same fixes as verl PR #5230):

1. **Gradient routing fix (critical)**: Replace `grad_scaler * dp_size` with `all_reduce(SUM)` in GatherVisionEmbeddings.backward() to aggregate partial sequence gradients before slicing. Fixes silent gradient loss.
2. **sync_vision_grads_across_cp: AVG→SUM**: With the activation gradient fix, each rank's ViT backward produces partial (not scaled) param gradients. SUM (not AVG) across CP now correctly recovers the total.
3. **Load-balanced assignment**: Replace count-based chunking with greedy contiguous bin-packing that balances total patch load across ranks.
4. **Remove unnecessary all_gather**: Pass the pre-computed `all_counts` from the caller instead of doing an all_gather in forward.
5. **Idempotency guard**: Add a `_vision_dp_patched` attribute check in apply_vision_dp_patch to prevent double-wrapping.
6. **Remove Qwen3-VL-MoE dead code**: Remove the unreachable qwen3_vl_moe block from apply_vision_dp_patch.
7. **GPU→CPU sync optimization**: Move `grid_thw.cpu()` to the dp_vision_forward entry point.
8. **Tensor slicing**: Replace the Python loop in prepare_local_vision_inputs with a contiguous tensor slice using cumsum.
9. **Test improvements**: Rename tests, add a load balancing test, add a gather_none_group test, use parametrize.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
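The load-balanced assignment in point 3 can be sketched as a single greedy pass: images stay in their original order (preserving the no-reorder property of contiguous distribution), and a rank's chunk is closed once adding the next image would push it past the ideal share of the total patch load. This is one possible greedy heuristic under those constraints; the PR's exact algorithm and function names may differ.

```python
def assign_contiguous_balanced(patch_counts: list[int], dp_size: int) -> list[list[int]]:
    """Greedy contiguous bin-packing: balance total patch load across ranks
    while keeping each rank's images a contiguous slice of the originals."""
    total = sum(patch_counts)
    target = total / dp_size                    # ideal per-rank patch load
    chunks = [[] for _ in range(dp_size)]
    rank, load = 0, 0
    for img, patches in enumerate(patch_counts):
        # Close this rank's chunk when the next image would overshoot the
        # target (keep the last rank open so every image gets assigned).
        if rank < dp_size - 1 and load > 0 and load + patches > target:
            rank, load = rank + 1, 0
        chunks[rank].append(img)
        load += patches
    return chunks

# Two heavy images followed by two tiny ones: count-based chunking would
# give loads 200 vs 2; the greedy split gives 100 vs 102.
print(assign_contiguous_balanced([100, 100, 1, 1], 2))
```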
Co-authored-by: gongyisheng <yishenggong9437@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…d contiguous guard

- Trim verbose docstrings to concise one-liners
- Delete dead store ctx.hidden_size (written in forward, never read in backward)
- Simplify hidden_size detection: self.config.out_hidden_size
- Add requires_grad_() so the empty rank participates in the backward all_reduce
- Add .contiguous() guard before all_reduce (NCCL requirement)
- Reuse get_image_patch_counts in the spatial_merge_size==1 path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Detect Qwen3-VL via a model attribute (hasattr deepstack_merger_list) instead of the return type, so empty ranks that skip original_forward still create matching empty deepstack tensors and participate in the all-gather, preventing NCCL deadlock.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add `vision_dp: bool = False` to FSDPArgs and gate both apply_vision_dp_patch() and sync_vision_grads_across_cp() behind it. Vision DP is now opt-in rather than auto-enabled when CP > 1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
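A minimal sketch of the opt-in gating described in this commit. `FSDPArgs` and `vision_dp` come from the PR; the `cp_size` field name and the helper function are assumptions for illustration.

```python
from dataclasses import dataclass

@dataclass
class FSDPArgs:
    cp_size: int = 1
    vision_dp: bool = False  # new opt-in flag from this commit

def should_enable_vision_dp(args: FSDPArgs) -> bool:
    # Vision DP only applies under Context Parallelism, and is no longer
    # auto-enabled just because cp_size > 1: the user must opt in.
    return args.vision_dp and args.cp_size > 1

print(should_enable_vision_dp(FSDPArgs(cp_size=4)))                  # off by default
print(should_enable_vision_dp(FSDPArgs(cp_size=4, vision_dp=True)))  # enabled on request
```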
…adixark#642)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
- Replace `expected_patches = end_patch - start_patch` (always true by Python slicing) with an independent cross-check via `get_image_patch_counts(local_grid_thw)` in prepare_local_vision_inputs()
- Rename tests to the `test_<what>_<condition>_<expected>()` convention
- Add missing tests: embedding_counts empty, contiguous coverage, gather same-storage

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sync shared utility functions with verl's stricter error handling:

- get_image_patch_counts/get_image_embedding_counts: empty grid_thw raises ValueError instead of returning []
- assign_images_to_dp_ranks: validate dp_size > 0; empty patch_counts raises ValueError instead of returning empty lists
- prepare_local_vision_inputs: add dp_rank bounds check, use tensor ops for offset computation (avoid a Python-list round-trip), add int() cast
- GatherVisionEmbeddings.forward: dp_size <= 1 raises RuntimeError, validate all_counts length, max_count == 0 raises RuntimeError
- GatherVisionEmbeddings.backward: assert dp_size > 1, add CUDA check
- dp_vision_forward: cp_size <= 1 raises RuntimeError, use GatherVisionEmbeddings.apply() directly, add detailed assert messages
- Update tests to match: empty → raises; add dp_size/dp_rank validation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
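The stricter error handling above can be illustrated with pure-Python stand-ins. These mirror the fail-fast style (raise on empty input, validate dp_size) rather than the PR's exact signatures; grid_thw rows are assumed to be (t, h, w) per image, with the patch count being their product.

```python
def get_image_patch_counts(grid_thw: list[tuple[int, int, int]]) -> list[int]:
    """Per-image patch counts; empty input now raises instead of returning []."""
    if len(grid_thw) == 0:
        raise ValueError("grid_thw must contain at least one image")
    return [t * h * w for t, h, w in grid_thw]

def assign_images_to_dp_ranks(patch_counts: list[int], dp_size: int):
    """Validate inputs before assignment (body elided in this sketch)."""
    if dp_size <= 0:
        raise ValueError(f"dp_size must be positive, got {dp_size}")
    if len(patch_counts) == 0:
        raise ValueError("patch_counts must be non-empty")
    ...  # contiguous load-balanced assignment happens here
```

Failing fast on empty or malformed inputs turns silent shape mismatches in the distributed gather into immediate, attributable errors.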
…ixark#656)

Co-authored-by: maocheng23 <35615230+maocheng23@users.noreply.github.com>
…xark#643)

Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Key changes from AReaL PR #929:

- Private `_` prefix on internal functions (cleaner public API)
- Simplified error handling: return empty instead of raising for edge cases
- Extract `_unpack_deepstack()` helper (was inline in dp_vision_forward)
- `_patch_vision_class()` as a standalone function with a `_VISION_CLASSES` registry
- `importlib`-based patching replaces repeated try/except import blocks
- Simplified `spatial_merge_size` lookup: `getattr(self, "spatial_merge_size", 1)`
- Hidden size fallback: `out_hidden_size` or `hidden_size`
- `.contiguous()` defensive guard in GatherVisionEmbeddings forward
- `dp_size == 1` short-circuit in GatherVisionEmbeddings (instead of raising)
- `cp_size <= 1` falls through to original_forward (instead of raising)
- Remove cross-check assertion in _prepare_local_vision_inputs
- Remove CUDA device check in backward (handled by NCCL)
- Use `_gather_vision_embeddings` wrapper consistently (including deepstack)
- Pass CPU grid_thw to _prepare_local_vision_inputs, move back to GPU after

Miles-specific (kept from the original):

- Closure-based CP group passing (no Ulysses SP APIs)
- `sync_vision_grads_across_cp()` for explicit ViT param grad sync

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
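The standalone patch function plus the idempotency guard (`_vision_dp_patched`, introduced earlier in this PR) can be sketched against a dummy class. The vision class and `make_forward` factory here are stand-ins for illustration; the real registry maps transformers vision model classes to their DP forwards.

```python
def _patch_vision_class(cls, make_forward):
    """Wrap cls.forward once; repeated calls are no-ops (idempotency guard)."""
    if getattr(cls, "_vision_dp_patched", False):
        return  # already wrapped: avoid stacking wrappers on double-patch
    cls.forward = make_forward(cls.forward)
    cls._vision_dp_patched = True

class DummyVisionModel:  # stand-in for a real ViT class from transformers
    def forward(self, x):
        return x

calls = []
def make_forward(original_forward):
    def dp_forward(self, x):
        calls.append("dp")  # record each pass through the wrapper
        return original_forward(self, x)
    return dp_forward

_patch_vision_class(DummyVisionModel, make_forward)
_patch_vision_class(DummyVisionModel, make_forward)  # no-op: guard fires

out = DummyVisionModel().forward(3)
print(out, calls)  # one call, one wrapper entry: the guard prevented stacking
```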
Both Glm4vVisionModel and Glm4vMoeVisionModel (GLM-5 744B) share the same forward signature as the Qwen series (hidden_states, grid_thw -> Tensor), so create_dp_vision_forward needs no changes; we just register them.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- `vision_dp: bool = False` in FSDPArgs: opt-in, default behavior unchanged
- `sync_vision_grads_across_cp` is also gated behind the flag

Key changes
- `miles/utils/vision_dp.py`
- `miles/backends/fsdp_utils/actor.py`: gated behind `self.args.vision_dp`
- `miles/backends/fsdp_utils/arguments.py`: add `vision_dp: bool = False` to `FSDPArgs`

Usage
Precision Alignment (verl reference experiment)
Validated in verl-project/verl#5230 under controlled conditions (same algorithm shared across frameworks):
- `all_reduce(SUM)` changes the FP accumulation order in the vision backward

Test plan
🤖 Generated with Claude Code