[parallel] feat: Vision Data Parallel — O(1) communication alternative to patch-level SP #505
aoshen524 wants to merge 5 commits into ByteDance-Seed:main
Conversation
Vision DP distributes whole images across SP ranks for ViT computation,
then all-gathers embeddings once. This eliminates per-layer all-to-all
communication in the ViT, which is especially beneficial for Qwen2.5-VL
(4 all-to-all ops reduced to 1 all-gather).
Key changes:
- Add veomni/distributed/sequence_parallel/vision_dp.py with:
- Image-level load-balanced distribution (greedy contiguous bin-packing)
- GatherVisionEmbeddings autograd function with correct gradient routing
(all_reduce SUM before slice to recover complete gradients)
- create_dp_vision_forward() wrapper for VisionModel.forward
- Modify Qwen3-VL and Qwen2.5-VL VisionModel.forward to accept _vision_dp
flag that skips patch-level SP when Vision DP handles distribution
- Add apply_vision_dp_patch() with idempotency guard
- Add 21 CPU-only unit tests
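The image-level load-balanced distribution in the first bullet can be sketched as follows. This is an illustrative reconstruction of the greedy contiguous strategy the PR names, not the code shipped in vision_dp.py; `assign_images_to_ranks` is an assumed name (the PR description only calls it "assign"):

```python
import torch


def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]:
    # Each image contributes t * h * w patches.
    return [int(t * h * w) for t, h, w in grid_thw.tolist()]


def assign_images_to_ranks(patch_counts: list[int], dp_size: int) -> list[list[int]]:
    # Greedy contiguous bin-packing: keep images in their original order and
    # open the next rank's bin once adding another image would overshoot an
    # even share of the total patch count.
    target = sum(patch_counts) / max(dp_size, 1)
    assignments: list[list[int]] = [[] for _ in range(dp_size)]
    rank, acc = 0, 0
    for idx, count in enumerate(patch_counts):
        if acc > 0 and acc + count > target and rank < dp_size - 1:
            rank, acc = rank + 1, 0
        assignments[rank].append(idx)
        acc += count
    return assignments
```

For example, two images of 36 and 64 patches on 2 ranks land on separate ranks (`[[0], [1]]`), while a single image with 2 ranks yields an empty second bin (`[[0], []]`) — the empty-rank case the unit tests exercise.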
Communication cost comparison:
- Patch-level SP (Qwen2.5-VL): 4 all-to-all per ViT forward
- Patch-level SP (Qwen3-VL): 1 all-to-all per ViT forward
- Vision DP (both): 1 all-gather after ViT (O(1) always)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request introduces Vision Data Parallel (Vision DP), a well-designed optimization that reduces communication overhead in Vision Transformers by distributing whole images instead of patches. The implementation is clean, backward-compatible via a feature flag, and includes a comprehensive set of unit tests. I've found a performance issue in the core vision_dp.py utility related to redundant computation and inefficient device transfers. I have provided comments with suggestions to address this, which also involves updating the corresponding unit tests. Overall, this is a high-quality and valuable contribution.
def prepare_local_vision_inputs(
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    image_assignments: list[list[int]],
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
There's a performance issue in prepare_local_vision_inputs. It re-computes patch_counts on line 180, which involves an inefficient GPU-CPU-GPU data transfer since grid_thw is on the GPU. The calling function, create_dp_vision_forward, already computes patch_counts on the CPU.
To fix this, you should:
- Modify this function's signature to accept `patch_counts: list[int]`.
- Remove the redundant computation on line 180.
- Update the call site in `create_dp_vision_forward` (lines 333-335) to pass the `patch_counts` variable.
def prepare_local_vision_inputs(
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    image_assignments: list[list[int]],
    dp_rank: int,
    patch_counts: list[int],
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:

class TestPrepareLocalVisionInputs:
    def test_prepare_two_images_splits_correctly(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]])  # 36 + 64 = 100
        image_assignments = [[0], [1]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert pix.shape[0] == 36
        assert grid.shape[0] == 1
        assert indices == [0]
        assert torch.allclose(pix, pixel_values[:36])

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)
        assert pix.shape[0] == 64
        assert grid.shape[0] == 1
        assert indices == [1]
        assert torch.allclose(pix, pixel_values[36:100])

    def test_prepare_multiple_contiguous_images_per_rank(self):
        pixel_values = torch.randn(200, 768)
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 4 x 50 patches
        image_assignments = [[0, 1], [2, 3]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert pix.shape[0] == 100
        assert grid.shape[0] == 2
        assert indices == [0, 1]
        assert torch.allclose(pix, pixel_values[:100])

    def test_prepare_empty_rank_returns_empty(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0], []]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)
        assert pix.shape[0] == 0
        assert grid.shape[0] == 0
        assert indices == []

    def test_prepare_grid_thw_preserved(self):
        pixel_values = torch.randn(150, 768)
        grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]])  # 25 + 50 + 75
        image_assignments = [[0, 1], [2]]

        _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert local_grid.shape == (2, 3)
        assert torch.equal(local_grid[0], grid_thw[0])
        assert torch.equal(local_grid[1], grid_thw[1])


class TestGatherVisionEmbeddings:
    def test_gather_none_group_returns_input(self):
        embeddings = torch.randn(10, 64)
        result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10])
        assert torch.equal(result, embeddings)


class TestIntegration:
Following the suggested change in vision_dp.py to pass patch_counts to prepare_local_vision_inputs, these tests need to be updated to pass the new argument. Note that the call in test_full_workflow_all_patches_covered on line 188 will also need to be updated similarly.
class TestPrepareLocalVisionInputs:
    def test_prepare_two_images_splits_correctly(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]])  # 36 + 64 = 100
        image_assignments = [[0], [1]]
        patch_counts = [36, 64]

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=0, patch_counts=patch_counts
        )
        assert pix.shape[0] == 36
        assert grid.shape[0] == 1
        assert indices == [0]
        assert torch.allclose(pix, pixel_values[:36])

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=1, patch_counts=patch_counts
        )
        assert pix.shape[0] == 64
        assert grid.shape[0] == 1
        assert indices == [1]
        assert torch.allclose(pix, pixel_values[36:100])

    def test_prepare_multiple_contiguous_images_per_rank(self):
        pixel_values = torch.randn(200, 768)
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 4 x 50 patches
        image_assignments = [[0, 1], [2, 3]]
        patch_counts = [50, 50, 50, 50]

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=0, patch_counts=patch_counts
        )
        assert pix.shape[0] == 100
        assert grid.shape[0] == 2
        assert indices == [0, 1]
        assert torch.allclose(pix, pixel_values[:100])

    def test_prepare_empty_rank_returns_empty(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0], []]
        patch_counts = [100]

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=1, patch_counts=patch_counts
        )
        assert pix.shape[0] == 0
        assert grid.shape[0] == 0
        assert indices == []

    def test_prepare_grid_thw_preserved(self):
        pixel_values = torch.randn(150, 768)
        grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]])  # 25 + 50 + 75
        image_assignments = [[0, 1], [2]]
        patch_counts = [25, 50, 75]

        _, local_grid, _ = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=0, patch_counts=patch_counts
        )
        assert local_grid.shape == (2, 3)
        assert torch.equal(local_grid[0], grid_thw[0])
        assert torch.equal(local_grid[1], grid_thw[1])

Avoids redundant get_image_patch_counts(grid_thw) call inside prepare_local_vision_inputs when the caller already computed it on CPU. Eliminates unnecessary GPU→CPU sync.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
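The fix this commit describes can be sketched as below. The body is an illustrative reconstruction consistent with the unit tests above (contiguous per-rank image assignment, precomputed `patch_counts` so `grid_thw` is never synced back to CPU), not the actual diff:

```python
import torch


def prepare_local_vision_inputs(
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    image_assignments: list[list[int]],
    dp_rank: int,
    patch_counts: list[int],
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    # Slice out this rank's images using the precomputed CPU-side
    # patch counts; no GPU->CPU transfer of grid_thw is needed here.
    indices = image_assignments[dp_rank]
    if not indices:
        # Empty rank: zero-patch tensors keep later collectives symmetric.
        return pixel_values[:0], grid_thw[:0], []
    # Assignments are contiguous, so one prefix-sum slice covers them all.
    offsets = [0]
    for count in patch_counts:
        offsets.append(offsets[-1] + count)
    start, end = offsets[indices[0]], offsets[indices[-1] + 1]
    return pixel_values[start:end], grid_thw[indices[0] : indices[-1] + 1], list(indices)
```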
    The final hidden states of the model.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
    The temporal, height and width of feature shape of each image in LLM.
_vision_dp: When True, skip patch-level SP logic (Vision DP handles
what if there are not enough images to distribute among vit dp ranks?
…d contiguous guard
- Trim verbose docstrings to concise one-liners
- Delete dead store `ctx.hidden_size` (written in forward, never read in backward)
- Simplify hidden_size detection: `self.config.out_hidden_size`
- Add `requires_grad_()` for empty rank to participate in backward all_reduce
- Add `.contiguous()` guard before all_reduce (NCCL requirement)
- Reuse `get_image_patch_counts` in spatial_merge_size==1 path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace isinstance(tuple) check with model attribute detection (hasattr deepstack_merger_list). Empty ranks now create matching empty deepstack tensors and participate in all-gather, preventing NCCL deadlock when num_images < dp_size. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add `vision_dp: bool = False` to TrainingArguments and `build_foundation_model()`. After model build, dispatch to the appropriate vision_dp patch based on model_type (qwen2_5_vl or qwen3_vl). Pass vision_dp from train_torch.py and train_qwen_vl.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
This PR introduces Vision DP — an alternative to patch-level Sequence Parallelism for Vision Transformers that trades per-layer communication for a simpler, communication-free ViT execution with a single post-ViT all-gather.
Why this is NOT traditional Data Parallelism
In Ulysses SP, all SP ranks must hold the identical text sequence for the sequence-parallel attention to work correctly. Since vision embeddings are embedded inline within the text sequence, all SP ranks also receive all images in the batch. This creates a unique problem:
Why you can't just change the dataloader: SP ranks are constrained to process the same sequence. Giving them different images would break text-side sequence parallelism — the attention all-to-all would receive mismatched KV from different ranks.
How the two approaches compare
Both patch-level SP (current) and Vision DP achieve the same ~1/sp_size ViT activation memory reduction — they just partition along different axes:
The key difference is in ViT-internal communication:
Qwen2.5-VL: Vision DP eliminates 4 boundary all-to-alls
Qwen2.5-VL's ViT uses window attention, which requires seeing the full sequence to compute
window_index/reverse_indices. With patch-level SP, this forces 4 boundary all-to-all operations (gather→reindex→scatter, before and after the transformer block loop).
Qwen3-VL: Comparable — current approach already has zero ViT-internal comms
Qwen3-VL has no window attention, so patch-level SP already achieves zero ViT-internal all-to-all. Vision DP adds 1 all-gather but may simplify model fusion.
When to use Vision DP
(no `cu_seqlens` SP slicing or `sp_pad_and_slice` inside ViT)
Key design choices
- Per-rank `cu_seqlens` tracking — no need to manipulate attention boundaries across SP ranks
- `all_reduce(SUM)` in backward before slicing recovers complete gradients from all SP ranks' sequence shards
- Opt-in via the `vision_dp=True` flag. Default behavior is unchanged.
Changes
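The all_reduce-before-slice rule for the backward pass can be illustrated with a custom autograd function. This is a hedged sketch of the idea, not the actual GatherVisionEmbeddings implementation from vision_dp.py; the argument layout (`dp_group`, `all_counts`, `dp_rank`) and the uneven all-gather are assumptions:

```python
import torch
import torch.distributed as dist


class GatherVisionEmbeddings(torch.autograd.Function):
    """All-gather per-rank ViT embeddings along dim 0.

    Backward: every SP rank holds the full text sequence, so each rank
    produces a partial gradient for *all* gathered embeddings. Summing
    gradients across ranks first, then slicing out the local shard,
    recovers the complete gradient for the locally computed embeddings.
    """

    @staticmethod
    def forward(ctx, local_emb, dp_group, all_counts, dp_rank):
        ctx.dp_group, ctx.all_counts, ctx.dp_rank = dp_group, all_counts, dp_rank
        if dp_group is None:  # single rank: nothing to gather
            return local_emb
        # Assumes the backend supports uneven list-based all_gather.
        gathered = [local_emb.new_empty(n, local_emb.shape[-1]) for n in all_counts]
        dist.all_gather(gathered, local_emb.contiguous(), group=dp_group)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.dp_group is None:
            return grad_output, None, None, None
        # Sum before slicing: each rank contributed a partial gradient
        # for every image's embeddings via its text-sequence shard.
        grad = grad_output.contiguous()  # NCCL requires contiguous buffers
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.dp_group)
        start = sum(ctx.all_counts[: ctx.dp_rank])
        end = start + ctx.all_counts[ctx.dp_rank]
        return grad[start:end], None, None, None
```

With `dp_group=None` the function degenerates to an identity, matching the `test_gather_none_group_returns_input` behavior in the PR's test suite.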
- `veomni/distributed/sequence_parallel/vision_dp.py` — core utilities (assign, prepare, gather, `create_dp_vision_forward` wrapper)
- `modeling_qwen3_vl.py` — `_vision_dp` flag to skip SP patches; `apply_vision_dp_patch()`; outer model skips post-ViT `gather_seq_scatter_heads` when Vision DP is active
- `modeling_qwen2_5_vl.py` — same pattern, eliminating 4 all-to-all ops in ViT
- `veomni/arguments/arguments_types.py` — added `vision_dp: bool = False` to `TrainingArguments`
- `veomni/models/auto.py` — `build_foundation_model()` accepts `vision_dp` and dispatches to model-specific patch
- `tasks/train_torch.py`, `tasks/omni/train_qwen_vl.py` — pass `vision_dp` from training args
Usage
# Enable via training argument
--train.vision_dp=true --train.ulysses_parallel_size=2

Precision Alignment: Vision DP On vs Off (verl reference experiment)
Precision alignment was validated in the verl framework (verl-project/verl#5230) under controlled conditions. The same Vision DP algorithm is shared across frameworks.
Experiment Setup
Results: Hash-Level (62 groups × 3 phases)
Results: Per-Parameter Element-wise Diff
Conclusion
`all_reduce(SUM)` in the backward pass changes floating-point accumulation order across SP ranks (SP=2), causing ~1e-05 level differences in vision gradients only.
Test plan
- 21 CPU-only unit tests pass (`pytest tests/parallel/test_vision_dp.py -v`)
- `ruff check` and `ruff format` pass on all modified files

🤖 Generated with Claude Code