Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions tests/utils/test_vision_dp_on_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for Vision Data Parallel utilities (CPU-only, no distributed).
"""

import pytest
import torch

from verl.utils.vision_dp import (
assign_images_to_dp_ranks,
gather_vision_embeddings,
get_image_embedding_counts,
get_image_patch_counts,
prepare_local_vision_inputs,
)


class TestGetImagePatchCounts:
    """get_image_patch_counts: per-image patch count is the product t * h * w."""

    @pytest.mark.parametrize(
        "grid_thw,expected",
        [
            ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]),
            ([[1, 4, 4]], [16]),
            ([[4, 4, 4]], [64]),
        ],
        ids=["multi-image", "single-image", "video-frames"],
    )
    def test_patch_counts_various_grids_correct_products(self, grid_thw, expected):
        # Each row of grid_thw is (t, h, w); the count list mirrors the row order.
        assert get_image_patch_counts(torch.tensor(grid_thw)) == expected

    def test_patch_counts_empty_input_returns_empty_list(self):
        # Zero images (0x3 grid tensor) must yield an empty list, not an error.
        no_images = torch.empty((0, 3), dtype=torch.long)
        assert get_image_patch_counts(no_images) == []


class TestGetImageEmbeddingCounts:
    """get_image_embedding_counts: patch count reduced by spatial merging.

    The parametrized cases are consistent with embeddings = (t*h*w) / merge_size**2.
    """

    @pytest.mark.parametrize(
        "grid_thw,merge_size,expected",
        [
            ([[1, 8, 8]], 1, [64]),
            ([[1, 8, 8]], 2, [16]),
            ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]),
        ],
        ids=["no-merge", "merge-2", "multi-image-merge"],
    )
    def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, expected):
        grid = torch.tensor(grid_thw)
        assert get_image_embedding_counts(grid, merge_size) == expected


class TestAssignImagesToDpRanks:
    """assign_images_to_dp_ranks: load-balanced, exhaustive image-to-rank assignment."""

    @pytest.mark.parametrize(
        "patch_counts,dp_size,expected_all_assigned",
        [
            ([100, 100, 100, 100], 2, True),
            ([100, 200, 300], 1, True),
            ([100, 100, 100, 100, 100, 100], 3, True),
        ],
        ids=["balanced-2ranks", "single-rank", "balanced-3ranks"],
    )
    def test_assign_all_images_distributed(self, patch_counts, dp_size, expected_all_assigned):
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size)
        # Every image index appears exactly once across all ranks.
        flattened = [idx for per_rank in assignments for idx in per_rank]
        assert sorted(flattened) == list(range(len(patch_counts)))
        # Total load is conserved across the split.
        assert sum(loads) == sum(patch_counts)

    def test_assign_fewer_images_than_ranks_all_assigned(self):
        assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4)
        # With 2 images and 4 ranks, exactly two ranks receive work.
        assert sum(1 for per_rank in assignments if per_rank) == 2
        assigned = {idx for per_rank in assignments for idx in per_rank}
        assert assigned == {0, 1}

    def test_assign_empty_input_returns_empty(self):
        # No images: every rank gets an empty assignment and zero load.
        assignments, loads = assign_images_to_dp_ranks([], dp_size=4)
        assert all(len(per_rank) == 0 for per_rank in assignments)
        assert all(load == 0 for load in loads)

    def test_assign_image_order_preserved_contiguous(self):
        assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2)
        # Within each rank, image indices stay in ascending (original) order.
        for per_rank in assignments:
            assert per_rank == sorted(per_rank)

    def test_assign_load_balanced_unequal_patches(self):
        """With unequal patch counts, greedy balancing should reduce imbalance."""
        # 4096 + 256 + 256 + 256 = 4864, target per rank = 2432
        assignments, loads = assign_images_to_dp_ranks([4096, 256, 256, 256], dp_size=2)
        # All images must be assigned
        flattened = [idx for per_rank in assignments for idx in per_rank]
        assert sorted(flattened) == [0, 1, 2, 3]
        # Load imbalance should be less than the naive count-based split (8.5x)
        nonzero_loads = [load for load in loads if load > 0]
        assert max(loads) / min(nonzero_loads) < 8.0


class TestPrepareLocalVisionInputs:
    """prepare_local_vision_inputs: slice out the pixel rows owned by one DP rank."""

    def test_prepare_two_images_splits_correctly(self):
        all_pixels = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]])  # 36 + 64 = 100
        image_assignments = [[0], [1]]

        # Rank 0 owns the first image: 36 patch rows from the front.
        local_pixels, local_grid, local_indices = prepare_local_vision_inputs(
            all_pixels, grid_thw, image_assignments, dp_rank=0
        )
        assert local_pixels.shape[0] == 36
        assert local_grid.shape[0] == 1
        assert local_indices == [0]
        assert torch.allclose(local_pixels, all_pixels[:36])

        # Rank 1 owns the second image: the remaining 64 rows.
        local_pixels, local_grid, local_indices = prepare_local_vision_inputs(
            all_pixels, grid_thw, image_assignments, dp_rank=1
        )
        assert local_pixels.shape[0] == 64
        assert local_indices == [1]
        assert torch.allclose(local_pixels, all_pixels[36:100])

    def test_prepare_multiple_contiguous_images_per_rank(self):
        all_pixels = torch.randn(200, 768)
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 4 x 50 patches
        image_assignments = [[0, 1], [2, 3]]

        # Rank 0 owns images 0 and 1: a single contiguous 100-row slice.
        local_pixels, local_grid, local_indices = prepare_local_vision_inputs(
            all_pixels, grid_thw, image_assignments, dp_rank=0
        )
        assert local_pixels.shape[0] == 100
        assert local_grid.shape[0] == 2
        assert local_indices == [0, 1]
        assert torch.allclose(local_pixels, all_pixels[:100])

    def test_prepare_empty_rank_returns_empty(self):
        all_pixels = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0], []]

        # A rank with no assigned images gets empty tensors and no indices.
        local_pixels, local_grid, local_indices = prepare_local_vision_inputs(
            all_pixels, grid_thw, image_assignments, dp_rank=1
        )
        assert local_pixels.shape[0] == 0
        assert local_grid.shape[0] == 0
        assert local_indices == []

    def test_prepare_grid_thw_preserved(self):
        all_pixels = torch.randn(150, 768)
        grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]])  # 25 + 50 + 75
        image_assignments = [[0, 1], [2]]

        # The local grid rows must be copied verbatim from the global grid.
        _, local_grid, _ = prepare_local_vision_inputs(all_pixels, grid_thw, image_assignments, dp_rank=0)
        assert local_grid.shape == (2, 3)
        assert torch.equal(local_grid[0], grid_thw[0])
        assert torch.equal(local_grid[1], grid_thw[1])


class TestGatherVisionEmbeddings:
    """CPU-only check: with no process group, gather must be an identity pass-through."""

    def test_gather_none_group_returns_input(self):
        local_embeddings = torch.randn(10, 64)
        gathered = gather_vision_embeddings(local_embeddings, dp_group=None, all_counts=[10])
        assert torch.equal(gathered, local_embeddings)


class TestIntegration:
    """End-to-end flow: count patches -> assign ranks -> slice local inputs."""

    def test_full_workflow_all_patches_covered(self):
        grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]])
        patch_counts = get_image_patch_counts(grid_thw)
        assert patch_counts == [16, 64, 16, 36, 16]

        total_patches = 16 + 64 + 16 + 36 + 16  # 148
        pixel_values = torch.randn(total_patches, 768)

        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
        # All five images distributed, none dropped or duplicated.
        flattened = [idx for per_rank in assignments for idx in per_rank]
        assert sorted(flattened) == [0, 1, 2, 3, 4]

        # Per-rank slices must jointly cover every patch row exactly once.
        covered_patches = 0
        for rank in range(2):
            local_pixels, local_grid, local_indices = prepare_local_vision_inputs(
                pixel_values, grid_thw, assignments, dp_rank=rank
            )
            assert local_pixels.shape[0] == sum(patch_counts[i] for i in local_indices)
            assert local_grid.shape[0] == len(local_indices)
            covered_patches += local_pixels.shape[0]

        assert covered_patches == total_patches

    def test_same_size_images_4_ranks_balanced(self):
        # 50 identical 64-patch images over 4 ranks -> 12 or 13 images per rank.
        grid_thw = torch.tensor([[1, 8, 8]] * 50)
        patch_counts = get_image_patch_counts(grid_thw)
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)

        for rank in range(4):
            assert 12 <= len(assignments[rank]) <= 13
        # Corresponding loads: 12 * 64 = 768 or 13 * 64 = 832.
        for load in loads:
            assert load in [768, 832]


if __name__ == "__main__":
    # Allow running this test module directly (python test_vision_dp_on_cpu.py)
    # without invoking pytest from the command line.
    pytest.main([__file__, "-v"])
83 changes: 83 additions & 0 deletions verl/models/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def apply_monkey_patch(
use_prefix_grouper: bool = False,
use_tiled_mlp: bool = False,
tiled_mlp_shards: int = 4,
vision_dp: bool = False,
):
"""
Apply monkey patch to the models for ulysses sequence parallel, fused kernel, tiled MLP and prefix grouper.
Expand Down Expand Up @@ -403,6 +404,52 @@ def state_dict(self, *args, **kwargs):
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)

# Step 4: patch VisionTransformer for Vision DP (image-level distribution)
if ulysses_sp_size > 1:
if vision_dp:
from verl.utils.vision_dp import create_dp_vision_forward

# Patch Qwen2-VL VisionTransformer
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

if not getattr(Qwen2VisionTransformerPretrainedModel, "_vision_dp_patched", False):
original_vision_forward = Qwen2VisionTransformerPretrainedModel.forward
Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(
original_vision_forward
)
Qwen2VisionTransformerPretrainedModel._vision_dp_patched = True
print(
f"Monkey patch Qwen2VisionTransformerPretrainedModel.forward"
f" for Vision DP (dp_size={ulysses_sp_size})"
)
except ImportError as e:
print(f"Warning: Could not patch Qwen2VisionTransformer for Vision DP: {e}")

# Patch Qwen2.5-VL VisionTransformer (uses a different class)
try:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
)

if not getattr(Qwen2_5_VisionTransformerPretrainedModel, "_vision_dp_patched", False):
original_vision_forward_25 = Qwen2_5_VisionTransformerPretrainedModel.forward
Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(
original_vision_forward_25
)
Qwen2_5_VisionTransformerPretrainedModel._vision_dp_patched = True
print(
f"Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward"
f" for Vision DP (dp_size={ulysses_sp_size})"
)
except ImportError as e:
print(f"Warning: Could not patch Qwen2_5VisionTransformer for Vision DP: {e}")
else:
print(
f"Vision DP disabled (vision_dp=False). "
f"ViT runs replicated on all {ulysses_sp_size} SP ranks."
)

Comment on lines +412 to +449
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code for monkey-patching the vision transformer is repeated multiple times in this file (here for Qwen2/2.5-VL, and again for Qwen3 models). This duplication makes the code harder to read and maintain. Please consider refactoring this logic into a helper function that takes the module path and class name as arguments. This would significantly reduce code duplication and improve maintainability.

For example, a helper could look like this:

def _patch_vision_model_for_dp(module_path, class_name, ulysses_sp_size):
    try:
        module = __import__(module_path, fromlist=[class_name])
        model_class = getattr(module, class_name)
        original_forward = model_class.forward
        model_class.forward = create_dp_vision_forward(original_forward)
        print(f"Monkey patch {class_name}.forward for Vision DP (dp_size={ulysses_sp_size})")
    except (ImportError, AttributeError) as e:
        print(f"Warning: Could not patch {class_name} for Vision DP: {e}")

Additionally, the step numbering is inconsistent ("Step 4" here, but "Step 3" for the Qwen3 models). A refactor would help resolve such inconsistencies.

Comment on lines +411 to +449
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is significant code duplication in this block and in the qwen3_vl block below (lines 487-522) for applying the Vision DP monkey patch. This repetitive logic can be refactored into a helper function to improve maintainability and reduce redundancy.

A helper function could encapsulate the try...except block, the import, the idempotency check, and the patching logic itself. This would make the code cleaner and easier to extend to other vision models in the future.

For example, you could define a helper function like this:

def _patch_vision_model_for_dp(model_class_name: str, module_path: str, ulysses_sp_size: int):
    """Applies the Vision DP monkey patch to a given model class."""
    try:
        from verl.utils.vision_dp import create_dp_vision_forward
        module = __import__(module_path, fromlist=[model_class_name])
        model_class = getattr(module, model_class_name)

        if not getattr(model_class, "_vision_dp_patched", False):
            original_forward = model_class.forward
            model_class.forward = create_dp_vision_forward(original_forward)
            model_class._vision_dp_patched = True
            print(
                f"Monkey patch {model_class_name}.forward for Vision DP (dp_size={ulysses_sp_size})"
            )
    except (ImportError, AttributeError) as e:
        print(f"Warning: Could not patch {model_class_name} for Vision DP: {e}")

And then call it for each model:

# Step 4: patch VisionTransformer for Vision DP (image-level distribution)
if ulysses_sp_size > 1:
    if vision_dp:
        _patch_vision_model_for_dp(
            "Qwen2VisionTransformerPretrainedModel",
            "transformers.models.qwen2_vl.modeling_qwen2_vl",
            ulysses_sp_size
        )
        _patch_vision_model_for_dp(
            "Qwen2_5_VisionTransformerPretrainedModel",
            "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
            ulysses_sp_size
        )
    else:
        print(
            f"Vision DP disabled (vision_dp=False). "
            f"ViT runs replicated on all {ulysses_sp_size} SP ranks."
        )

This refactoring would also apply to the qwen3_vl block.

elif model.config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
# Step 1: patch model to support image-text mixed data
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
Expand Down Expand Up @@ -437,6 +484,42 @@ def state_dict(self, *args, **kwargs):
patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel)
patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel)

# Step 3: patch VisionTransformer for Vision DP (image-level distribution)
if ulysses_sp_size > 1:
if vision_dp:
from verl.utils.vision_dp import create_dp_vision_forward

# Patch Qwen3-VL VisionModel
try:
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

if not getattr(Qwen3VLVisionModel, "_vision_dp_patched", False):
original_vision_forward_q3 = Qwen3VLVisionModel.forward
Qwen3VLVisionModel.forward = create_dp_vision_forward(original_vision_forward_q3)
Qwen3VLVisionModel._vision_dp_patched = True
print(f"Monkey patch Qwen3VLVisionModel.forward for Vision DP (dp_size={ulysses_sp_size})")
except ImportError as e:
print(f"Warning: Could not patch Qwen3VLVisionModel for Vision DP: {e}")

# Patch Qwen3-VL-MoE VisionModel
try:
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel

if not getattr(Qwen3VLMoeVisionModel, "_vision_dp_patched", False):
original_vision_forward_q3moe = Qwen3VLMoeVisionModel.forward
Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original_vision_forward_q3moe)
Qwen3VLMoeVisionModel._vision_dp_patched = True
print(
f"Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP (dp_size={ulysses_sp_size})"
)
except ImportError as e:
print(f"Warning: Could not patch Qwen3VLMoeVisionModel for Vision DP: {e}")
else:
print(
f"Vision DP disabled (vision_dp=False). "
f"ViT runs replicated on all {ulysses_sp_size} SP ranks."
)

elif model.config.model_type == "glm4v":
# Step 1: patch model to support image-text mixed data

Expand Down
4 changes: 4 additions & 0 deletions verl/trainer/config/model/hf_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ tiled_mlp:
# number of shards to split the input. Higher values reduce peak memory but may slightly impact performance.
num_shards: 4

# whether to enable Vision DP (distribute ViT computation across Ulysses SP ranks by image).
# Only effective when ulysses_sequence_parallel_size > 1.
vision_dp: False

# MTP
mtp:

Expand Down
Loading