Skip to content

Commit 76ba823

Browse files
aoshen524 authored and claude committed
feat(vision): add Vision DP for parallel ViT computation across SP ranks
When Ulysses sequence parallelism is enabled (sp_size > 1), the VisionTransformer processes all images on every rank redundantly. This adds Vision Data Parallel (DP) which distributes whole images across SP ranks for independent ViT processing, then all-gathers embeddings once at the end. This avoids breaking cu_seqlens semantics that would occur if patches within images were split across ranks. Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL, and Qwen3-VL-MoE. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 06449b8 commit 76ba823

File tree

3 files changed

+807
-0
lines changed

3 files changed

+807
-0
lines changed

tests/test_vision_dp.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Unit tests for Vision Data Parallel utilities.
16+
"""
17+
18+
import pytest
19+
import torch
20+
21+
from verl.utils.vision_dp import (
22+
assign_images_to_dp_ranks,
23+
get_image_patch_counts,
24+
prepare_local_vision_inputs,
25+
)
26+
27+
28+
class TestGetImagePatchCounts:
    """Unit tests for the get_image_patch_counts helper."""

    def test_basic_patch_counts(self):
        """Each image's count is the product t*h*w of its grid row."""
        grids = torch.tensor(
            [
                [2, 4, 4],  # 2*4*4 = 32
                [1, 2, 2],  # 1*2*2 = 4
                [1, 8, 8],  # 1*8*8 = 64
            ]
        )
        assert get_image_patch_counts(grids) == [32, 4, 64]

    def test_single_image(self):
        """A lone 1x4x4 image yields a single count of 16."""
        assert get_image_patch_counts(torch.tensor([[1, 4, 4]])) == [16]

    def test_empty_input(self):
        """An empty (0, 3) grid tensor produces an empty list."""
        no_images = torch.empty((0, 3), dtype=torch.long)
        assert get_image_patch_counts(no_images) == []

    def test_video_frames(self):
        """Temporal frames multiply into the per-image patch total."""
        video = torch.tensor([[4, 4, 4]])  # 4 frames of 4x4 patches = 64
        assert get_image_patch_counts(video) == [64]
class TestAssignImagesToDpRanks:
    """Unit tests for the assign_images_to_dp_ranks helper."""

    def test_balanced_assignment(self):
        """Four equal-sized images split evenly across two ranks."""
        assignments, loads = assign_images_to_dp_ranks([100, 100, 100, 100], dp_size=2)
        assert len(assignments[0]) == 2
        assert len(assignments[1]) == 2
        assert loads[0] == 200
        assert loads[1] == 200

    def test_imbalanced_images(self):
        """One large image: every image is still assigned exactly once."""
        assignments, loads = assign_images_to_dp_ranks([500, 100, 100, 100], dp_size=2)
        assert sum(len(per_rank) for per_rank in assignments) == 4
        # The greedy balancer places the big image (index 0) on one of the ranks.
        assert 0 in assignments[0] or 0 in assignments[1]

    def test_fewer_images_than_ranks(self):
        """More ranks than images: surplus ranks stay empty, nothing is lost."""
        assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4)
        assert sum(1 for per_rank in assignments if len(per_rank) > 0) == 2
        assigned = set()
        for per_rank in assignments:
            assigned.update(per_rank)
        assert assigned == {0, 1}

    def test_empty_input(self):
        """No images: every rank is empty and carries zero load."""
        assignments, loads = assign_images_to_dp_ranks([], dp_size=4)
        assert all(len(per_rank) == 0 for per_rank in assignments)
        assert all(load == 0 for load in loads)

    def test_single_rank(self):
        """dp_size=1 routes all images to the lone rank."""
        assignments, loads = assign_images_to_dp_ranks([100, 200, 300], dp_size=1)
        assert assignments == [[0, 1, 2]]
        assert loads == [600]

    def test_equal_images_equal_size(self):
        """Six equal images across three ranks balance perfectly."""
        assignments, loads = assign_images_to_dp_ranks([100] * 6, dp_size=3)
        assert all(len(per_rank) == 2 for per_rank in assignments)
        assert all(load == 200 for load in loads)

    def test_image_order_preserved(self):
        """Each rank's image indices come back in sorted order."""
        assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2)
        for per_rank in assignments:
            assert per_rank == sorted(per_rank)
class TestPrepareLocalVisionInputs:
    """Unit tests for the prepare_local_vision_inputs helper."""

    def test_basic_extraction(self):
        """Each rank receives exactly the patch rows of its assigned image."""
        pixel_values = torch.randn(100, 768)  # 36 + 64 patches, 768-dim
        grid_thw = torch.tensor(
            [
                [1, 6, 6],  # 36 patches (rows 0-35)
                [1, 8, 8],  # 64 patches (rows 36-99)
            ]
        )
        assignments = [[0], [1]]  # one image per rank

        pix0, grid0, idx0 = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=0)
        assert pix0.shape[0] == 36
        assert grid0.shape[0] == 1
        assert idx0 == [0]
        assert torch.allclose(pix0, pixel_values[:36])

        pix1, grid1, idx1 = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=1)
        assert pix1.shape[0] == 64
        assert grid1.shape[0] == 1
        assert idx1 == [1]
        assert torch.allclose(pix1, pixel_values[36:100])

    def test_multiple_images_per_rank(self):
        """Non-contiguous images on one rank are concatenated in index order."""
        pixel_values = torch.randn(200, 768)  # four 50-patch images
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 50 patches each
        assignments = [[0, 2], [1, 3]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=0)
        assert pix.shape[0] == 100  # 50 + 50
        assert grid.shape[0] == 2
        assert indices == [0, 2]
        # Rows for image 0 (0:50) followed by rows for image 2 (100:150).
        expected = torch.cat([pixel_values[0:50], pixel_values[100:150]], dim=0)
        assert torch.allclose(pix, expected)

    def test_empty_rank(self):
        """A rank with no assigned images gets empty tensors and indices."""
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])  # 100 patches, all on rank 0

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, [[0], []], dp_rank=1)
        assert pix.shape[0] == 0
        assert grid.shape[0] == 0
        assert indices == []

    def test_grid_thw_preserved(self):
        """Per-image grid rows follow their images to the local rank."""
        pixel_values = torch.randn(150, 768)  # 25 + 50 + 75 patches
        grid_thw = torch.tensor(
            [
                [1, 5, 5],  # 25 patches
                [2, 5, 5],  # 50 patches
                [3, 5, 5],  # 75 patches
            ]
        )

        _, grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, [[0, 2], [1]], dp_rank=0)
        assert grid.shape == (2, 3)
        assert torch.equal(grid[0], grid_thw[0])
        assert torch.equal(grid[1], grid_thw[2])
class TestIntegration:
    """Integration tests combining multiple functions.

    Fix: test_same_size_images allocated a 3200x768 random tensor into `_`
    that was never used; the dead allocation (and the locals feeding it) is
    removed. All assertions are unchanged.
    """

    def test_full_workflow(self):
        """Test the complete workflow of image distribution."""
        # Simulate 5 images with different sizes.
        grid_thw = torch.tensor(
            [
                [1, 4, 4],  # 16 patches
                [1, 8, 8],  # 64 patches
                [1, 4, 4],  # 16 patches
                [1, 6, 6],  # 36 patches
                [1, 4, 4],  # 16 patches
            ]
        )
        total_patches = 16 + 64 + 16 + 36 + 16  # 148 patches
        pixel_values = torch.randn(total_patches, 768)

        # Step 1: get patch counts.
        patch_counts = get_image_patch_counts(grid_thw)
        assert patch_counts == [16, 64, 16, 36, 16]

        # Step 2: assign images to 2 ranks; every image must be assigned once.
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
        all_assigned = []
        for a in assignments:
            all_assigned.extend(a)
        assert sorted(all_assigned) == [0, 1, 2, 3, 4]

        # Step 3: extract local inputs for each rank and check consistency.
        total_local_patches = 0
        for rank in range(2):
            local_pix, local_grid, local_indices = prepare_local_vision_inputs(
                pixel_values, grid_thw, assignments, dp_rank=rank
            )
            expected_patches = sum(patch_counts[i] for i in local_indices)
            assert local_pix.shape[0] == expected_patches
            assert local_grid.shape[0] == len(local_indices)
            total_local_patches += local_pix.shape[0]

        # Patches across all ranks must partition the original set.
        assert total_local_patches == total_patches

    def test_same_size_images(self):
        """Test with all same-size images (user's scenario)."""
        num_images = 50
        grid_thw = torch.tensor([[1, 8, 8]] * num_images)  # 64 patches each

        patch_counts = get_image_patch_counts(grid_thw)
        assert all(c == 64 for c in patch_counts)

        # With 4 DP ranks, 50 images split as 12 or 13 per rank.
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)
        for rank in range(4):
            assert 12 <= len(assignments[rank]) <= 13

        # Loads should be balanced (either 12*64=768 or 13*64=832).
        for load in loads:
            assert load in [768, 832]
# Allow running this test module directly (`python test_vision_dp.py`)
# in addition to pytest collection.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

verl/models/transformers/monkey_patch.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,38 @@ def state_dict(self, *args, **kwargs):
403403
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
404404
patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
405405

406+
# Step 4: patch VisionTransformer for Vision DP (image-level distribution)
407+
if ulysses_sp_size > 1:
408+
from verl.utils.vision_dp import create_dp_vision_forward
409+
410+
# Patch Qwen2-VL VisionTransformer
411+
try:
412+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
413+
414+
original_vision_forward = Qwen2VisionTransformerPretrainedModel.forward
415+
Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original_vision_forward)
416+
print(
417+
f"Monkey patch Qwen2VisionTransformerPretrainedModel.forward"
418+
f" for Vision DP (dp_size={ulysses_sp_size})"
419+
)
420+
except ImportError as e:
421+
print(f"Warning: Could not patch Qwen2VisionTransformer for Vision DP: {e}")
422+
423+
# Patch Qwen2.5-VL VisionTransformer (uses a different class)
424+
try:
425+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
426+
Qwen2_5_VisionTransformerPretrainedModel,
427+
)
428+
429+
original_vision_forward_25 = Qwen2_5_VisionTransformerPretrainedModel.forward
430+
Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original_vision_forward_25)
431+
print(
432+
f"Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward"
433+
f" for Vision DP (dp_size={ulysses_sp_size})"
434+
)
435+
except ImportError as e:
436+
print(f"Warning: Could not patch Qwen2_5VisionTransformer for Vision DP: {e}")
437+
406438
elif model.config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
407439
# Step 1: patch model to support image-text mixed data
408440
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
@@ -437,6 +469,30 @@ def state_dict(self, *args, **kwargs):
437469
patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel)
438470
patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel)
439471

472+
# Step 3: patch VisionTransformer for Vision DP (image-level distribution)
473+
if ulysses_sp_size > 1:
474+
from verl.utils.vision_dp import create_dp_vision_forward
475+
476+
# Patch Qwen3-VL VisionModel
477+
try:
478+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
479+
480+
original_vision_forward_q3 = Qwen3VLVisionModel.forward
481+
Qwen3VLVisionModel.forward = create_dp_vision_forward(original_vision_forward_q3)
482+
print(f"Monkey patch Qwen3VLVisionModel.forward for Vision DP (dp_size={ulysses_sp_size})")
483+
except ImportError as e:
484+
print(f"Warning: Could not patch Qwen3VLVisionModel for Vision DP: {e}")
485+
486+
# Patch Qwen3-VL-MoE VisionModel
487+
try:
488+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
489+
490+
original_vision_forward_q3moe = Qwen3VLMoeVisionModel.forward
491+
Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original_vision_forward_q3moe)
492+
print(f"Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP (dp_size={ulysses_sp_size})")
493+
except ImportError as e:
494+
print(f"Warning: Could not patch Qwen3VLMoeVisionModel for Vision DP: {e}")
495+
440496
elif model.config.model_type == "glm4v":
441497
# Step 1: patch model to support image-text mixed data
442498

0 commit comments

Comments
 (0)