-
Notifications
You must be signed in to change notification settings - Fork 163
[parallel] feat: Vision Data Parallel — O(1) communication alternative to patch-level SP #505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aoshen524
wants to merge
5
commits into
ByteDance-Seed:main
Choose a base branch
from
aoshen524:feat/vision-dp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+754
−16
Open
Changes from 2 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
74d3a6b
[parallel] feat: add Vision Data Parallel as alternative to patch-lev…
aoshen524 6ea042b
fix: pass pre-computed patch_counts to prepare_local_vision_inputs
aoshen524 817be83
refactor(vision_dp): simplify docstrings, fix empty-rank backward, ad…
aoshen524 1b9efc2
fix(vision_dp): fix Qwen3-VL deepstack NCCL deadlock on empty ranks
aoshen524 4783f2c
feat(vision_dp): add vision_dp flag to gate Vision DP patching
aoshen524 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| """Stub heavy dependencies to allow CPU-only testing of vision_dp utilities. | ||
|
|
||
| The veomni import chain pulls in datasets, flash_attn, CUDA ops, etc. | ||
| We stub everything except the specific module under test (vision_dp.py) | ||
| so that pytest can collect and run the tests on a CPU-only machine. | ||
| """ | ||
|
|
||
| import importlib | ||
| import sys | ||
| import types | ||
|
|
||
|
|
||
| def _ensure_stub(name, **attrs): | ||
| """Create a stub module with __path__ if it doesn't already exist.""" | ||
| if name in sys.modules: | ||
| mod = sys.modules[name] | ||
| else: | ||
| mod = types.ModuleType(name) | ||
| mod.__path__ = [name.replace(".", "/")] | ||
| sys.modules[name] = mod | ||
| for k, v in attrs.items(): | ||
| setattr(mod, k, v) | ||
| return mod | ||
|
|
||
|
|
||
| # ── Stub veomni top-level (prevent __init__.py from importing ops/data) ── | ||
| _ensure_stub("veomni") | ||
|
|
||
| _ensure_stub("veomni.ops") | ||
| _ensure_stub("veomni.data") | ||
| _ensure_stub("veomni.data.constants", IGNORE_INDEX=-100) | ||
|
|
||
| # ── Stub utils ── | ||
| _ensure_stub("veomni.utils") | ||
|
|
||
|
|
||
| class _FakeLogger: | ||
| def __getattr__(self, name): | ||
| return lambda *a, **kw: None | ||
|
|
||
|
|
||
| _ensure_stub("veomni.utils.logging", get_logger=lambda name=None: _FakeLogger()) | ||
| _ensure_stub( | ||
| "veomni.utils.device", | ||
| get_device_type=lambda: "cpu", | ||
| get_device_id=lambda: "cpu", | ||
| IS_NPU_AVAILABLE=False, | ||
| IS_CUDA_AVAILABLE=False, | ||
| ) | ||
| _ensure_stub("veomni.utils.import_utils", is_torch_version_greater_than=lambda v: True) | ||
|
|
||
| # ── Stub distributed ── | ||
| _ensure_stub("veomni.distributed") | ||
|
|
||
|
|
||
| class _FakeParallelState: | ||
| sp_enabled = False | ||
| sp_size = 1 | ||
| sp_rank = 0 | ||
| sp_group = None | ||
|
|
||
|
|
||
| _ensure_stub( | ||
| "veomni.distributed.parallel_state", | ||
| get_parallel_state=lambda: _FakeParallelState(), | ||
| ParallelState=_FakeParallelState, | ||
| ) | ||
|
|
||
| # ── Stub the sequence_parallel __init__ and its heavy sub-modules ── | ||
| # We need to prevent the real __init__.py from running (it imports | ||
| # async_ulysses, comm, data, loss, ulysses, utils which have heavy deps). | ||
| # So we register the package stub FIRST, then load vision_dp.py directly. | ||
| _sp_pkg = _ensure_stub("veomni.distributed.sequence_parallel") | ||
|
|
||
| for _sub in [ | ||
| "veomni.distributed.sequence_parallel.async_ulysses", | ||
| "veomni.distributed.sequence_parallel.comm", | ||
| "veomni.distributed.sequence_parallel.data", | ||
| "veomni.distributed.sequence_parallel.loss", | ||
| "veomni.distributed.sequence_parallel.ulysses", | ||
| "veomni.distributed.sequence_parallel.utils", | ||
| ]: | ||
| _ensure_stub(_sub) | ||
|
|
||
| # Now load vision_dp.py for real (it only depends on torch, dist, parallel_state) | ||
| _vision_dp_spec = importlib.util.spec_from_file_location( | ||
| "veomni.distributed.sequence_parallel.vision_dp", | ||
| "veomni/distributed/sequence_parallel/vision_dp.py", | ||
| ) | ||
| _vision_dp_mod = importlib.util.module_from_spec(_vision_dp_spec) | ||
| sys.modules["veomni.distributed.sequence_parallel.vision_dp"] = _vision_dp_mod | ||
| _vision_dp_spec.loader.exec_module(_vision_dp_mod) | ||
|
|
||
| # Attach it to the parent package | ||
| _sp_pkg.vision_dp = _vision_dp_mod |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| # Copyright 2025 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
| Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from veomni.distributed.sequence_parallel.vision_dp import ( | ||
| assign_images_to_dp_ranks, | ||
| gather_vision_embeddings, | ||
| get_image_embedding_counts, | ||
| get_image_patch_counts, | ||
| prepare_local_vision_inputs, | ||
| ) | ||
|
|
||
|
|
||
| class TestGetImagePatchCounts: | ||
| @pytest.mark.parametrize( | ||
| "grid_thw,expected", | ||
| [ | ||
| ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]), | ||
| ([[1, 4, 4]], [16]), | ||
| ([[4, 4, 4]], [64]), | ||
| ], | ||
| ids=["multi-image", "single-image", "video-frames"], | ||
| ) | ||
| def test_patch_counts_various_grids_correct_products(self, grid_thw, expected): | ||
| counts = get_image_patch_counts(torch.tensor(grid_thw)) | ||
| assert counts == expected | ||
|
|
||
| def test_patch_counts_empty_input_returns_empty_list(self): | ||
| counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) | ||
| assert counts == [] | ||
|
|
||
|
|
||
| class TestGetImageEmbeddingCounts: | ||
| @pytest.mark.parametrize( | ||
| "grid_thw,merge_size,expected", | ||
| [ | ||
| ([[1, 8, 8]], 1, [64]), | ||
| ([[1, 8, 8]], 2, [16]), | ||
| ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]), | ||
| ], | ||
| ids=["no-merge", "merge-2", "multi-image-merge"], | ||
| ) | ||
| def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, expected): | ||
| counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size) | ||
| assert counts == expected | ||
|
|
||
|
|
||
| class TestAssignImagesToDpRanks: | ||
| @pytest.mark.parametrize( | ||
| "patch_counts,dp_size", | ||
| [ | ||
| ([100, 100, 100, 100], 2), | ||
| ([100, 200, 300], 1), | ||
| ([100, 100, 100, 100, 100, 100], 3), | ||
| ], | ||
| ids=["balanced-2ranks", "single-rank", "balanced-3ranks"], | ||
| ) | ||
| def test_assign_all_images_distributed(self, patch_counts, dp_size): | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size) | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == list(range(len(patch_counts))) | ||
| assert sum(loads) == sum(patch_counts) | ||
|
|
||
| def test_assign_fewer_images_than_ranks_all_assigned(self): | ||
| assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4) | ||
| non_empty = sum(1 for a in assignments if len(a) > 0) | ||
| assert non_empty == 2 | ||
| all_assigned = set() | ||
| for a in assignments: | ||
| all_assigned.update(a) | ||
| assert all_assigned == {0, 1} | ||
|
|
||
| def test_assign_empty_input_returns_empty(self): | ||
| assignments, loads = assign_images_to_dp_ranks([], dp_size=4) | ||
| assert all(len(a) == 0 for a in assignments) | ||
| assert all(load == 0 for load in loads) | ||
|
|
||
| def test_assign_image_order_preserved_contiguous(self): | ||
| assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2) | ||
| for rank_assignment in assignments: | ||
| assert rank_assignment == sorted(rank_assignment) | ||
|
|
||
| def test_assign_load_balanced_unequal_patches(self): | ||
| """With unequal patch counts, greedy balancing should reduce imbalance.""" | ||
| patch_counts = [4096, 256, 256, 256] | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == [0, 1, 2, 3] | ||
| max_load = max(loads) | ||
| min_load = min(load for load in loads if load > 0) | ||
| assert max_load / min_load < 8.0 | ||
|
|
||
|
|
||
| class TestPrepareLocalVisionInputs: | ||
| def test_prepare_two_images_splits_correctly(self): | ||
| pixel_values = torch.randn(100, 768) | ||
| grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]]) # 36 + 64 = 100 | ||
| image_assignments = [[0], [1]] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) | ||
| assert pix.shape[0] == 36 | ||
| assert grid.shape[0] == 1 | ||
| assert indices == [0] | ||
| assert torch.allclose(pix, pixel_values[:36]) | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) | ||
| assert pix.shape[0] == 64 | ||
| assert grid.shape[0] == 1 | ||
| assert indices == [1] | ||
| assert torch.allclose(pix, pixel_values[36:100]) | ||
|
|
||
| def test_prepare_multiple_contiguous_images_per_rank(self): | ||
| pixel_values = torch.randn(200, 768) | ||
| grid_thw = torch.tensor([[1, 5, 10]] * 4) # 4 x 50 patches | ||
| image_assignments = [[0, 1], [2, 3]] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) | ||
| assert pix.shape[0] == 100 | ||
| assert grid.shape[0] == 2 | ||
| assert indices == [0, 1] | ||
| assert torch.allclose(pix, pixel_values[:100]) | ||
|
|
||
| def test_prepare_empty_rank_returns_empty(self): | ||
| pixel_values = torch.randn(100, 768) | ||
| grid_thw = torch.tensor([[1, 10, 10]]) | ||
| image_assignments = [[0], []] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) | ||
| assert pix.shape[0] == 0 | ||
| assert grid.shape[0] == 0 | ||
| assert indices == [] | ||
|
|
||
| def test_prepare_grid_thw_preserved(self): | ||
| pixel_values = torch.randn(150, 768) | ||
| grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75 | ||
| image_assignments = [[0, 1], [2]] | ||
|
|
||
| _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) | ||
| assert local_grid.shape == (2, 3) | ||
| assert torch.equal(local_grid[0], grid_thw[0]) | ||
| assert torch.equal(local_grid[1], grid_thw[1]) | ||
|
|
||
|
|
||
| class TestGatherVisionEmbeddings: | ||
| def test_gather_none_group_returns_input(self): | ||
| embeddings = torch.randn(10, 64) | ||
| result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10]) | ||
| assert torch.equal(result, embeddings) | ||
|
|
||
|
|
||
| class TestIntegration: | ||
| def test_full_workflow_all_patches_covered(self): | ||
| grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]]) | ||
| total_patches = 16 + 64 + 16 + 36 + 16 # 148 | ||
| pixel_values = torch.randn(total_patches, 768) | ||
|
|
||
| patch_counts = get_image_patch_counts(grid_thw) | ||
| assert patch_counts == [16, 64, 16, 36, 16] | ||
|
|
||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == [0, 1, 2, 3, 4] | ||
|
|
||
| total_local_patches = 0 | ||
| for rank in range(2): | ||
| pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=rank) | ||
| expected = sum(patch_counts[i] for i in indices) | ||
| assert pix.shape[0] == expected | ||
| assert grid.shape[0] == len(indices) | ||
| total_local_patches += pix.shape[0] | ||
|
|
||
| assert total_local_patches == total_patches | ||
|
|
||
| def test_same_size_images_4_ranks_balanced(self): | ||
| num_images = 50 | ||
| grid_thw = torch.tensor([[1, 8, 8]] * num_images) | ||
| patch_counts = get_image_patch_counts(grid_thw) | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) | ||
|
|
||
| for rank in range(4): | ||
| assert 12 <= len(assignments[rank]) <= 13 | ||
| for load in loads: | ||
| assert load in [768, 832] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the suggested change in
vision_dp.pyto passpatch_countstoprepare_local_vision_inputs, these tests need to be updated to pass the new argument. Note that the call intest_full_workflow_all_patches_coveredon line 188 will also need to be updated similarly.