-
Notifications
You must be signed in to change notification settings - Fork 3.4k
feat(vision): add Vision DP for parallel ViT computation across SP ranks #5230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
8763b64
5f86648
1215c9d
b777e1c
b4dc330
92d7dfe
9289e8e
911dcb0
cbc53fa
8e6d326
9970a2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
| Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from verl.utils.vision_dp import ( | ||
| assign_images_to_dp_ranks, | ||
| gather_vision_embeddings, | ||
| get_image_embedding_counts, | ||
| get_image_patch_counts, | ||
| prepare_local_vision_inputs, | ||
| ) | ||
|
|
||
|
|
||
| class TestGetImagePatchCounts: | ||
| @pytest.mark.parametrize( | ||
| "grid_thw,expected", | ||
| [ | ||
| ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]), | ||
| ([[1, 4, 4]], [16]), | ||
| ([[4, 4, 4]], [64]), | ||
| ], | ||
| ids=["multi-image", "single-image", "video-frames"], | ||
| ) | ||
| def test_patch_counts_various_grids_correct_products(self, grid_thw, expected): | ||
| counts = get_image_patch_counts(torch.tensor(grid_thw)) | ||
| assert counts == expected | ||
|
|
||
| def test_patch_counts_empty_input_returns_empty_list(self): | ||
| counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) | ||
| assert counts == [] | ||
|
|
||
|
|
||
| class TestGetImageEmbeddingCounts: | ||
| @pytest.mark.parametrize( | ||
| "grid_thw,merge_size,expected", | ||
| [ | ||
| ([[1, 8, 8]], 1, [64]), | ||
| ([[1, 8, 8]], 2, [16]), | ||
| ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]), | ||
| ], | ||
| ids=["no-merge", "merge-2", "multi-image-merge"], | ||
| ) | ||
| def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, expected): | ||
| counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size) | ||
| assert counts == expected | ||
|
|
||
|
|
||
| class TestAssignImagesToDpRanks: | ||
| @pytest.mark.parametrize( | ||
| "patch_counts,dp_size,expected_all_assigned", | ||
| [ | ||
| ([100, 100, 100, 100], 2, True), | ||
| ([100, 200, 300], 1, True), | ||
| ([100, 100, 100, 100, 100, 100], 3, True), | ||
| ], | ||
| ids=["balanced-2ranks", "single-rank", "balanced-3ranks"], | ||
| ) | ||
| def test_assign_all_images_distributed(self, patch_counts, dp_size, expected_all_assigned): | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size) | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == list(range(len(patch_counts))) | ||
| assert sum(loads) == sum(patch_counts) | ||
|
|
||
| def test_assign_fewer_images_than_ranks_all_assigned(self): | ||
| assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4) | ||
| non_empty = sum(1 for a in assignments if len(a) > 0) | ||
| assert non_empty == 2 | ||
| all_assigned = set() | ||
| for a in assignments: | ||
| all_assigned.update(a) | ||
| assert all_assigned == {0, 1} | ||
|
|
||
| def test_assign_empty_input_returns_empty(self): | ||
| assignments, loads = assign_images_to_dp_ranks([], dp_size=4) | ||
| assert all(len(a) == 0 for a in assignments) | ||
| assert all(load == 0 for load in loads) | ||
|
|
||
| def test_assign_image_order_preserved_contiguous(self): | ||
| assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2) | ||
| for rank_assignment in assignments: | ||
| assert rank_assignment == sorted(rank_assignment) | ||
|
|
||
| def test_assign_load_balanced_unequal_patches(self): | ||
| """With unequal patch counts, greedy balancing should reduce imbalance.""" | ||
| # 4096 + 256 + 256 + 256 = 4864, target per rank = 2432 | ||
| patch_counts = [4096, 256, 256, 256] | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) | ||
| # All images must be assigned | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == [0, 1, 2, 3] | ||
| # Load imbalance should be less than the naive count-based split (8.5x) | ||
| max_load = max(loads) | ||
| min_load = min(load for load in loads if load > 0) | ||
| assert max_load / min_load < 8.0 | ||
|
|
||
|
|
||
| class TestPrepareLocalVisionInputs: | ||
| def test_prepare_two_images_splits_correctly(self): | ||
| pixel_values = torch.randn(100, 768) | ||
| grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]]) # 36 + 64 = 100 | ||
| image_assignments = [[0], [1]] | ||
|
|
||
| # Rank 0 | ||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) | ||
| assert pix.shape[0] == 36 | ||
| assert grid.shape[0] == 1 | ||
| assert indices == [0] | ||
| assert torch.allclose(pix, pixel_values[:36]) | ||
|
|
||
| # Rank 1 | ||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) | ||
| assert pix.shape[0] == 64 | ||
| assert indices == [1] | ||
| assert torch.allclose(pix, pixel_values[36:100]) | ||
|
|
||
| def test_prepare_multiple_contiguous_images_per_rank(self): | ||
| pixel_values = torch.randn(200, 768) | ||
| grid_thw = torch.tensor([[1, 5, 10]] * 4) # 4 x 50 patches | ||
| image_assignments = [[0, 1], [2, 3]] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) | ||
| assert pix.shape[0] == 100 | ||
| assert grid.shape[0] == 2 | ||
| assert indices == [0, 1] | ||
| assert torch.allclose(pix, pixel_values[:100]) | ||
|
|
||
| def test_prepare_empty_rank_returns_empty(self): | ||
| pixel_values = torch.randn(100, 768) | ||
| grid_thw = torch.tensor([[1, 10, 10]]) | ||
| image_assignments = [[0], []] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) | ||
| assert pix.shape[0] == 0 | ||
| assert grid.shape[0] == 0 | ||
| assert indices == [] | ||
|
|
||
| def test_prepare_grid_thw_preserved(self): | ||
| pixel_values = torch.randn(150, 768) | ||
| grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75 | ||
| image_assignments = [[0, 1], [2]] | ||
|
|
||
| _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) | ||
| assert local_grid.shape == (2, 3) | ||
| assert torch.equal(local_grid[0], grid_thw[0]) | ||
| assert torch.equal(local_grid[1], grid_thw[1]) | ||
|
|
||
|
|
||
| class TestGatherVisionEmbeddings: | ||
| def test_gather_none_group_returns_input(self): | ||
| embeddings = torch.randn(10, 64) | ||
| result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10]) | ||
| assert torch.equal(result, embeddings) | ||
|
|
||
|
|
||
| class TestIntegration: | ||
| def test_full_workflow_all_patches_covered(self): | ||
| grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]]) | ||
| total_patches = 16 + 64 + 16 + 36 + 16 # 148 | ||
| pixel_values = torch.randn(total_patches, 768) | ||
|
|
||
| patch_counts = get_image_patch_counts(grid_thw) | ||
| assert patch_counts == [16, 64, 16, 36, 16] | ||
|
|
||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == [0, 1, 2, 3, 4] | ||
|
|
||
| total_local_patches = 0 | ||
| for rank in range(2): | ||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=rank) | ||
| expected = sum(patch_counts[i] for i in indices) | ||
| assert pix.shape[0] == expected | ||
| assert grid.shape[0] == len(indices) | ||
| total_local_patches += pix.shape[0] | ||
|
|
||
| assert total_local_patches == total_patches | ||
|
|
||
| def test_same_size_images_4_ranks_balanced(self): | ||
| num_images = 50 | ||
| grid_thw = torch.tensor([[1, 8, 8]] * num_images) | ||
| patch_counts = get_image_patch_counts(grid_thw) | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) | ||
|
|
||
| for rank in range(4): | ||
| assert 12 <= len(assignments[rank]) <= 13 | ||
| for load in loads: | ||
| assert load in [768, 832] | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v"]) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -292,6 +292,7 @@ def apply_monkey_patch( | |
| use_prefix_grouper: bool = False, | ||
| use_tiled_mlp: bool = False, | ||
| tiled_mlp_shards: int = 4, | ||
| vision_dp: bool = False, | ||
| ): | ||
| """ | ||
| Apply monkey patch to the models for ulysses sequence parallel, fused kernel, tiled MLP and prefix grouper. | ||
|
|
@@ -403,6 +404,52 @@ def state_dict(self, *args, **kwargs): | |
| patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) | ||
| patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) | ||
|
|
||
| # Step 4: patch VisionTransformer for Vision DP (image-level distribution) | ||
| if ulysses_sp_size > 1: | ||
| if vision_dp: | ||
| from verl.utils.vision_dp import create_dp_vision_forward | ||
|
|
||
| # Patch Qwen2-VL VisionTransformer | ||
| try: | ||
| from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel | ||
|
|
||
| if not getattr(Qwen2VisionTransformerPretrainedModel, "_vision_dp_patched", False): | ||
| original_vision_forward = Qwen2VisionTransformerPretrainedModel.forward | ||
| Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward( | ||
| original_vision_forward | ||
| ) | ||
| Qwen2VisionTransformerPretrainedModel._vision_dp_patched = True | ||
| print( | ||
| f"Monkey patch Qwen2VisionTransformerPretrainedModel.forward" | ||
| f" for Vision DP (dp_size={ulysses_sp_size})" | ||
| ) | ||
| except ImportError as e: | ||
| print(f"Warning: Could not patch Qwen2VisionTransformer for Vision DP: {e}") | ||
|
|
||
| # Patch Qwen2.5-VL VisionTransformer (uses a different class) | ||
| try: | ||
| from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( | ||
| Qwen2_5_VisionTransformerPretrainedModel, | ||
| ) | ||
|
|
||
| if not getattr(Qwen2_5_VisionTransformerPretrainedModel, "_vision_dp_patched", False): | ||
| original_vision_forward_25 = Qwen2_5_VisionTransformerPretrainedModel.forward | ||
| Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward( | ||
| original_vision_forward_25 | ||
| ) | ||
| Qwen2_5_VisionTransformerPretrainedModel._vision_dp_patched = True | ||
| print( | ||
| f"Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward" | ||
| f" for Vision DP (dp_size={ulysses_sp_size})" | ||
| ) | ||
| except ImportError as e: | ||
| print(f"Warning: Could not patch Qwen2_5VisionTransformer for Vision DP: {e}") | ||
| else: | ||
| print( | ||
| f"Vision DP disabled (vision_dp=False). " | ||
| f"ViT runs replicated on all {ulysses_sp_size} SP ranks." | ||
| ) | ||
|
|
||
|
Comment on lines
+411
to
+449
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

There is significant code duplication in this block and in the Qwen3-VL block below. A helper function could encapsulate the try/import/patch-and-log logic. For example, you could define a helper function like this:

def _patch_vision_model_for_dp(model_class_name: str, module_path: str, ulysses_sp_size: int):
"""Applies the Vision DP monkey patch to a given model class."""
try:
from verl.utils.vision_dp import create_dp_vision_forward
module = __import__(module_path, fromlist=[model_class_name])
model_class = getattr(module, model_class_name)
if not getattr(model_class, "_vision_dp_patched", False):
original_forward = model_class.forward
model_class.forward = create_dp_vision_forward(original_forward)
model_class._vision_dp_patched = True
print(
f"Monkey patch {model_class_name}.forward for Vision DP (dp_size={ulysses_sp_size})"
)
except (ImportError, AttributeError) as e:
        print(f"Warning: Could not patch {model_class_name} for Vision DP: {e}")

And then call it for each model:

# Step 4: patch VisionTransformer for Vision DP (image-level distribution)
if ulysses_sp_size > 1:
if vision_dp:
_patch_vision_model_for_dp(
"Qwen2VisionTransformerPretrainedModel",
"transformers.models.qwen2_vl.modeling_qwen2_vl",
ulysses_sp_size
)
_patch_vision_model_for_dp(
"Qwen2_5_VisionTransformerPretrainedModel",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
ulysses_sp_size
)
else:
print(
f"Vision DP disabled (vision_dp=False). "
f"ViT runs replicated on all {ulysses_sp_size} SP ranks."
        )

This refactoring would also apply to the Qwen3-VL / Qwen3-VL-MoE block below. |
||
| elif model.config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: | ||
| # Step 1: patch model to support image-text mixed data | ||
| from transformers.models.qwen3_vl.modeling_qwen3_vl import ( | ||
|
|
@@ -437,6 +484,42 @@ def state_dict(self, *args, **kwargs): | |
| patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel) | ||
| patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel) | ||
|
|
||
| # Step 3: patch VisionTransformer for Vision DP (image-level distribution) | ||
| if ulysses_sp_size > 1: | ||
| if vision_dp: | ||
| from verl.utils.vision_dp import create_dp_vision_forward | ||
|
|
||
| # Patch Qwen3-VL VisionModel | ||
| try: | ||
| from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel | ||
|
|
||
| if not getattr(Qwen3VLVisionModel, "_vision_dp_patched", False): | ||
| original_vision_forward_q3 = Qwen3VLVisionModel.forward | ||
| Qwen3VLVisionModel.forward = create_dp_vision_forward(original_vision_forward_q3) | ||
| Qwen3VLVisionModel._vision_dp_patched = True | ||
| print(f"Monkey patch Qwen3VLVisionModel.forward for Vision DP (dp_size={ulysses_sp_size})") | ||
| except ImportError as e: | ||
| print(f"Warning: Could not patch Qwen3VLVisionModel for Vision DP: {e}") | ||
|
|
||
| # Patch Qwen3-VL-MoE VisionModel | ||
| try: | ||
| from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel | ||
|
|
||
| if not getattr(Qwen3VLMoeVisionModel, "_vision_dp_patched", False): | ||
| original_vision_forward_q3moe = Qwen3VLMoeVisionModel.forward | ||
| Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original_vision_forward_q3moe) | ||
| Qwen3VLMoeVisionModel._vision_dp_patched = True | ||
| print( | ||
| f"Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP (dp_size={ulysses_sp_size})" | ||
| ) | ||
| except ImportError as e: | ||
| print(f"Warning: Could not patch Qwen3VLMoeVisionModel for Vision DP: {e}") | ||
| else: | ||
| print( | ||
| f"Vision DP disabled (vision_dp=False). " | ||
| f"ViT runs replicated on all {ulysses_sp_size} SP ranks." | ||
| ) | ||
|
|
||
| elif model.config.model_type == "glm4v": | ||
| # Step 1: patch model to support image-text mixed data | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code for monkey-patching the vision transformer is repeated multiple times in this file (here for Qwen2/2.5-VL, and again for Qwen3 models). This duplication makes the code harder to read and maintain. Please consider refactoring this logic into a helper function that takes the module path and class name as arguments. This would significantly reduce code duplication and improve maintainability.
For example, a helper could look like this:
Additionally, the step numbering is inconsistent ("Step 4" here, but "Step 3" for the Qwen3 models). A refactor would help resolve such inconsistencies.