
[parallel] feat: Vision Data Parallel — O(1) communication alternative to patch-level SP #505

Open

aoshen524 wants to merge 5 commits into ByteDance-Seed:main from aoshen524:feat/vision-dp

Conversation

@aoshen524 aoshen524 commented Feb 24, 2026

Summary

This PR introduces Vision DP — an alternative to patch-level Sequence Parallelism for Vision Transformers that trades per-layer communication for a simpler, communication-free ViT execution with a single post-ViT all-gather.

Why this is NOT traditional Data Parallelism

Key distinction: Vision DP operates across SP ranks (which share the same micro-batch), NOT DP ranks (which have different data).

In Ulysses SP, all SP ranks must hold the identical text sequence for the sequence-parallel attention to work correctly. Since vision embeddings are embedded inline within the text sequence, all SP ranks also receive all images in the batch. This creates a unique problem:

| | Traditional DP | Vision DP (this PR) |
|---|---|---|
| Ranks | DP ranks (different data) | SP ranks (same data) |
| Input | Each rank has a different micro-batch | All ranks have the same micro-batch |
| Fix via dataloader? | Yes — dataloader shards data | No — SP ranks must share the same input |
| What's distributed | Entire forward/backward | Only the ViT forward; embeddings are all-gathered back |

Why you can't just change the dataloader: SP ranks are constrained to process the same sequence. Giving them different images would break text-side sequence parallelism — the attention all-to-all would receive mismatched KV from different ranks.

How the two approaches compare

Both patch-level SP (current) and Vision DP achieve the same ~1/sp_size ViT activation memory reduction — they just partition along different axes:

| | Patch-level SP (current) | Vision DP (this PR) |
|---|---|---|
| Partition axis | Each rank processes 1/sp_size patches of ALL images | Each rank processes ALL patches of 1/sp_size images |
| ViT activation memory | ~total_patches/sp_size per rank | ~total_patches/sp_size per rank |
| ViT internal comms | Model-specific (see below) | Zero — each rank runs the ViT independently |
| Post-ViT | — | 1 all-gather to collect all embeddings |

The key difference is in ViT-internal communication:

Qwen2.5-VL: Vision DP eliminates 4 boundary all-to-alls

Qwen2.5-VL's ViT uses window attention, which requires seeing the full sequence to compute window_index / reverse_indices. With patch-level SP, this forces 4 boundary all-to-all operations (gather→reindex→scatter, before and after the transformer block loop).

| | Patch-level SP | Vision DP |
|---|---|---|
| ViT internal | 4 all-to-all (window attention boundary gather/scatter) | 0 |
| Post-ViT | — | 1 all-gather |
| Model fusion | 3–4 all-to-all | similar |

Net ViT comms: 4 all-to-all → 1 all-gather.

Qwen3-VL: Comparable — current approach already has zero ViT-internal comms

Qwen3-VL has no window attention, so patch-level SP already achieves zero ViT-internal all-to-all. Vision DP adds 1 all-gather but may simplify model fusion.

| | Patch-level SP | Vision DP |
|---|---|---|
| ViT internal | 0 | 0 |
| Post-ViT | — | 1 all-gather |
| Model fusion | 3 all-to-all + N×all-gather (deepstack) | similar |

Net ViT comms: comparable — no clear winner.

When to use Vision DP

  • Qwen2.5-VL: Recommended — eliminates 4 window attention boundary all-to-alls, simpler ViT implementation
  • Qwen3-VL: Optional — similar communication cost, but simpler ViT code (no cu_seqlens SP slicing or sp_pad_and_slice inside ViT)
  • General: Useful when you want a model-agnostic ViT SP solution that doesn't require per-model attention patching

Key design choices

  • Image-level distribution: avoids breaking ViT's internal cu_seqlens tracking — no need to manipulate attention boundaries across SP ranks
  • Load-balanced contiguous bin-packing: assigns images to ranks by patch count, preserving order so no reordering is needed after gather
  • Correct gradient routing: all_reduce(SUM) in backward before slicing recovers complete gradients from all SP ranks' sequence shards
  • Backward-compatible: controlled via vision_dp=True flag. Default behavior is unchanged.
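The contiguous bin-packing design choice can be sketched in a few lines. This is an illustrative reimplementation, not the PR's actual code; the name `assign_images_to_ranks` and its signature are assumptions.

```python
def assign_images_to_ranks(patch_counts, dp_size):
    """Split images into dp_size contiguous groups with roughly equal patch load.

    Contiguity is the key property: each rank owns a consecutive slice of the
    image list, so the post-ViT all-gather restores the original order with no
    re-indexing step.
    """
    target = sum(patch_counts) / dp_size  # ideal patch load per rank
    assignments = [[] for _ in range(dp_size)]
    rank, load = 0, 0
    for img, count in enumerate(patch_counts):
        # Advance to the next rank when adding this image would overshoot the
        # target, but never move past the last rank.
        if rank < dp_size - 1 and load > 0 and load + count > target:
            rank += 1
            load = 0
        assignments[rank].append(img)
        load += count
    return assignments
```

With two images of 36 and 64 patches and `dp_size=2`, this yields `[[0], [1]]`; with fewer images than ranks, trailing ranks receive empty lists, which is exactly the empty-rank case the implementation must handle in the all-gather.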

Changes

  • New: veomni/distributed/sequence_parallel/vision_dp.py — core utilities (assign, prepare, gather, create_dp_vision_forward wrapper)
  • Modified: modeling_qwen3_vl.py — _vision_dp flag to skip SP patches; apply_vision_dp_patch(); the outer model skips the post-ViT gather_seq_scatter_heads when Vision DP is active
  • Modified: modeling_qwen2_5_vl.py — same pattern, eliminating 4 all-to-all ops in ViT
  • Modified: veomni/arguments/arguments_types.py — added vision_dp: bool = False to TrainingArguments
  • Modified: veomni/models/auto.py — build_foundation_model() accepts vision_dp and dispatches to the model-specific patch
  • Modified: tasks/train_torch.py, tasks/omni/train_qwen_vl.py — pass vision_dp from training args
  • New: 21 CPU-only unit tests for all Vision DP utilities
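The backward routing behind the gather step (all_reduce(SUM) before slicing) can be illustrated in a single process, with plain lists standing in for tensors and the SP process group; `route_vision_grads` is a hypothetical name for illustration, not the PR's API.

```python
def route_vision_grads(per_rank_full_grads, all_counts):
    """Simulate the backward pass of the post-ViT embedding gather.

    Each SP rank back-propagates only the part of the loss that comes from its
    own sequence shard, so each rank holds a *partial* gradient for the full,
    gathered embedding sequence.  Summing across ranks (the all_reduce) yields
    the complete gradient; each rank then slices out the patches of the images
    it ran locally.
    """
    # Simulated all_reduce(SUM): element-wise sum over ranks.
    full = [sum(vals) for vals in zip(*per_rank_full_grads)]
    # Slice per rank according to local patch counts.
    slices, start = [], 0
    for n in all_counts:
        slices.append(full[start:start + n])
        start += n
    return slices
```

For two SP ranks and three patches, where rank 0's sequence shard contributes gradient `[1.0, 0.0, 2.0]` and rank 1's contributes `[0.5, 3.0, 0.0]`, and rank 0 owns the first two patches, the routed result is `[[1.5, 3.0], [2.0]]`: each rank receives the complete gradient for exactly the patches it computed.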

Usage

```shell
# Enable via training argument
--train.vision_dp=true --train.ulysses_parallel_size=2
```

Precision Alignment: Vision DP On vs Off (verl reference experiment)

Precision alignment was validated in the verl framework (verl-project/verl#5230) under controlled conditions. The same Vision DP algorithm is shared across frameworks.

Experiment Setup

| Config | Value |
|---|---|
| Model | Qwen2.5-VL-7B (SFT checkpoint) |
| GPUs | 4 × H100 (1 node) |
| SP size | 2 |
| Strategy | FSDP2 |
| Train batch size | 4 |
| Rollout N | 2 |
| Max assistant turns | 4 |
| Dummy mode | full (fixed multi-image prompts with meaningful screenshots + fixed text) |
| Determinism | enabled (seed=42, TF32 disabled, CUBLAS_WORKSPACE_CONFIG=:4096:8) |
| Advantage | hardcoded to 1.0 to remove reward randomness |
| Training steps | 1 |

Results: Hash-Level (62 groups × 3 phases)

| Phase | Param SHA256 | Grad SHA256 Match | Vision Grad MM | Language Grad MM |
|---|---|---|---|---|
| pre_clip (raw grad) | 62/62 ✅ | 29/62 | 32 | 0 |
| before (post-clip) | 62/62 ✅ | 0/62 | 32 | 29 |
| after (post-step) | 0/62 | 0/62 | 32 | 29 |

Language gradients are bitwise identical at pre_clip — Vision DP does not affect the language model's gradient computation. The before/after phase divergence is expected: global grad norm (used for clipping) includes vision gradients, so a different clip factor propagates to all parameters.
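The propagation mechanism is visible from how the clip coefficient is formed: the global norm mixes vision and language contributions, and the resulting scalar multiplies every gradient. A minimal sketch with hypothetical numbers (not the experiment's actual values):

```python
import math

def clip_coef(grad_norms_squared, max_norm=1.0):
    # Global grad norm over all parameter groups (vision + language + other).
    global_norm = math.sqrt(sum(grad_norms_squared))
    # Standard clipping: scale all gradients by max_norm / global_norm, capped at 1.
    return min(1.0, max_norm / global_norm)

base = clip_coef([4.0, 9.0])              # vision and language contributions
perturbed = clip_coef([4.0 + 1e-5, 9.0])  # tiny vision-side difference
# base != perturbed, and the coefficient multiplies *every* gradient, so
# bitwise-identical language gradients diverge after clipping.
```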

Results: Per-Parameter Element-wise Diff

| Scope | Params | max_diff | mean_diff | cosine_sim |
|---|---|---|---|---|
| vision | 390 | 4.70e-05 | 2.93e-08 | 0.9991 |
| language | 338 | 9.50e-08 | 1.15e-10 | 1.0020 |
| other | 1 | 9.13e-08 | 2.25e-13 | 1.0001 |

Conclusion

  • Vision DP is numerically lossless: all differences fall within bf16 floating-point precision (~1e-05 max).
  • Root cause: all_reduce(SUM) in the backward pass changes floating-point accumulation order across SP ranks (SP=2), causing ~1e-05 level differences in vision gradients only.
  • Language is unaffected: pre-clip language gradients are bitwise identical; post-clip differences are solely due to global grad norm clipping propagation.
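The root cause is ordinary floating-point non-associativity; a one-line demonstration (exact in float64, and far more pronounced in bf16):

```python
# Same three terms, two groupings, two different sums.
terms = [1e16, 1.0, -1e16]
left_to_right = (terms[0] + terms[1]) + terms[2]  # the 1.0 is absorbed: 0.0
cancel_first = (terms[0] + terms[2]) + terms[1]   # cancellation first: 1.0
# all_reduce(SUM) across SP ranks regroups the gradient accumulation in just
# this way, which is why only the vision gradients (the reduced ones) differ.
```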

Test plan

  • 21 CPU-only unit tests pass (pytest tests/parallel/test_vision_dp.py -v)
  • ruff check and ruff format pass on all modified files
  • Multi-GPU distributed test with SP enabled (requires GPU cluster)

🤖 Generated with Claude Code

…el SP

Vision DP distributes whole images across SP ranks for ViT computation,
then all-gathers embeddings once. This eliminates per-layer all-to-all
communication in the ViT, which is especially beneficial for Qwen2.5-VL
(4 all-to-all ops reduced to 1 all-gather).

Key changes:
- Add veomni/distributed/sequence_parallel/vision_dp.py with:
  - Image-level load-balanced distribution (greedy contiguous bin-packing)
  - GatherVisionEmbeddings autograd function with correct gradient routing
    (all_reduce SUM before slice to recover complete gradients)
  - create_dp_vision_forward() wrapper for VisionModel.forward
- Modify Qwen3-VL and Qwen2.5-VL VisionModel.forward to accept _vision_dp
  flag that skips patch-level SP when Vision DP handles distribution
- Add apply_vision_dp_patch() with idempotency guard
- Add 21 CPU-only unit tests

Communication cost comparison:
- Patch-level SP (Qwen2.5-VL): 4 all-to-all per ViT forward
- Patch-level SP (Qwen3-VL): 1 all-to-all per ViT forward
- Vision DP (both): 1 all-gather after ViT (O(1) always)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces Vision Data Parallel (Vision DP), a well-designed optimization that reduces communication overhead in Vision Transformers by distributing whole images instead of patches. The implementation is clean, backward-compatible via a feature flag, and includes a comprehensive set of unit tests. I've found a performance issue in the core vision_dp.py utility related to redundant computation and inefficient device transfers. I have provided comments with suggestions to address this, which also involves updating the corresponding unit tests. Overall, this is a high-quality and valuable contribution.

Comment on lines +145 to +150
def prepare_local_vision_inputs(
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    image_assignments: list[list[int]],
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:

Severity: high

There's a performance issue in prepare_local_vision_inputs. It re-computes patch_counts on line 180, which involves an inefficient GPU-CPU-GPU data transfer since grid_thw is on the GPU. The calling function, create_dp_vision_forward, already computes patch_counts on the CPU.

To fix this, you should:

  1. Modify this function's signature to accept patch_counts: list[int].
  2. Remove the redundant computation on line 180.
  3. Update the call site in create_dp_vision_forward (lines 333-335) to pass the patch_counts variable.
def prepare_local_vision_inputs(
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    image_assignments: list[list[int]],
    dp_rank: int,
    patch_counts: list[int],
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:

Comment on lines +114 to +171
class TestPrepareLocalVisionInputs:
    def test_prepare_two_images_splits_correctly(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]])  # 36 + 64 = 100
        image_assignments = [[0], [1]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert pix.shape[0] == 36
        assert grid.shape[0] == 1
        assert indices == [0]
        assert torch.allclose(pix, pixel_values[:36])

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)
        assert pix.shape[0] == 64
        assert grid.shape[0] == 1
        assert indices == [1]
        assert torch.allclose(pix, pixel_values[36:100])

    def test_prepare_multiple_contiguous_images_per_rank(self):
        pixel_values = torch.randn(200, 768)
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 4 x 50 patches
        image_assignments = [[0, 1], [2, 3]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert pix.shape[0] == 100
        assert grid.shape[0] == 2
        assert indices == [0, 1]
        assert torch.allclose(pix, pixel_values[:100])

    def test_prepare_empty_rank_returns_empty(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0], []]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)
        assert pix.shape[0] == 0
        assert grid.shape[0] == 0
        assert indices == []

    def test_prepare_grid_thw_preserved(self):
        pixel_values = torch.randn(150, 768)
        grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]])  # 25 + 50 + 75
        image_assignments = [[0, 1], [2]]

        _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert local_grid.shape == (2, 3)
        assert torch.equal(local_grid[0], grid_thw[0])
        assert torch.equal(local_grid[1], grid_thw[1])


class TestGatherVisionEmbeddings:
    def test_gather_none_group_returns_input(self):
        embeddings = torch.randn(10, 64)
        result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10])
        assert torch.equal(result, embeddings)


class TestIntegration:

Severity: high

Following the suggested change in vision_dp.py to pass patch_counts to prepare_local_vision_inputs, these tests need to be updated to pass the new argument. Note that the call in test_full_workflow_all_patches_covered on line 188 will also need to be updated similarly.

class TestPrepareLocalVisionInputs:
    def test_prepare_two_images_splits_correctly(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]])  # 36 + 64 = 100
        image_assignments = [[0], [1]]
        patch_counts = [36, 64]

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=0, patch_counts=patch_counts
        )
        assert pix.shape[0] == 36
        assert grid.shape[0] == 1
        assert indices == [0]
        assert torch.allclose(pix, pixel_values[:36])

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=1, patch_counts=patch_counts
        )
        assert pix.shape[0] == 64
        assert grid.shape[0] == 1
        assert indices == [1]
        assert torch.allclose(pix, pixel_values[36:100])

    def test_prepare_multiple_contiguous_images_per_rank(self):
        pixel_values = torch.randn(200, 768)
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 4 x 50 patches
        image_assignments = [[0, 1], [2, 3]]
        patch_counts = [50, 50, 50, 50]

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=0, patch_counts=patch_counts
        )
        assert pix.shape[0] == 100
        assert grid.shape[0] == 2
        assert indices == [0, 1]
        assert torch.allclose(pix, pixel_values[:100])

    def test_prepare_empty_rank_returns_empty(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0], []]
        patch_counts = [100]

        pix, grid, indices = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=1, patch_counts=patch_counts
        )
        assert pix.shape[0] == 0
        assert grid.shape[0] == 0
        assert indices == []

    def test_prepare_grid_thw_preserved(self):
        pixel_values = torch.randn(150, 768)
        grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]])  # 25 + 50 + 75
        image_assignments = [[0, 1], [2]]
        patch_counts = [25, 50, 75]

        _, local_grid, _ = prepare_local_vision_inputs(
            pixel_values, grid_thw, image_assignments, dp_rank=0, patch_counts=patch_counts
        )
        assert local_grid.shape == (2, 3)
        assert torch.equal(local_grid[0], grid_thw[0])
        assert torch.equal(local_grid[1], grid_thw[1])

Avoids redundant get_image_patch_counts(grid_thw) call inside
prepare_local_vision_inputs when the caller already computed it on CPU.
Eliminates unnecessary GPU→CPU sync.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The final hidden states of the model.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
The temporal, height and width of feature shape of each image in LLM.
_vision_dp: When True, skip patch-level SP logic (Vision DP handles

what if there are not enough images to distribute among vit dp ranks?

aoshen524 and others added 3 commits March 3, 2026 22:52
…d contiguous guard

- Trim verbose docstrings to concise one-liners
- Delete dead store ctx.hidden_size (written in forward, never read in backward)
- Simplify hidden_size detection: self.config.out_hidden_size
- Add requires_grad_() for empty rank to participate in backward all_reduce
- Add .contiguous() guard before all_reduce (NCCL requirement)
- Reuse get_image_patch_counts in spatial_merge_size==1 path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace isinstance(tuple) check with model attribute detection
(hasattr deepstack_merger_list). Empty ranks now create matching
empty deepstack tensors and participate in all-gather, preventing
NCCL deadlock when num_images < dp_size.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add `vision_dp: bool = False` to TrainingArguments and
`build_foundation_model()`. After model build, dispatch to
the appropriate vision_dp patch based on model_type
(qwen2_5_vl or qwen3_vl). Pass vision_dp from train_torch.py
and train_qwen_vl.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>