-
Notifications
You must be signed in to change notification settings - Fork 417
feat(models): shard vision encoder across Ulysses SP ranks #929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5625ae5
[Feature] Add Vision Data Parallel to distribute ViT computation acro…
aoshen524 73ae996
fix(vision-dp): address all review comments for Vision DP
aoshen524 a6944c5
refactor(vision_dp): simplify docstrings, fix empty-rank backward, ad…
aoshen524 b42049d
fix(vision_dp): add Qwen3-VL deepstack support with empty-rank safety
aoshen524 58bb743
feat(vision_dp): add vision_dp config flag to gate Vision DP patching
aoshen524 3311103
fix(vision_dp): replace tautological assertion with independent cross…
aoshen524 8ae9f06
test(vision_dp): improve test naming, add patch and forward coverage
aoshen524 2ac391f
Merge remote-tracking branch 'upstream/main' into feat/vision-dp
rchardx 0ea8b20
format
rchardx 9b69166
refactor: rename vision_dp to shard_vision_across_sp and fix bugs
rchardx 276a5af
refactor(models): move vision_sp_shard to models/fsdp and harden helpers
rchardx da84052
refactor(models): improve code quality of vision_sp_shard
rchardx 1111e54
refactor(models): move vision_sp_shard to transformers/ and improve t…
rchardx 5ac0f26
format
rchardx File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| """ | ||
| Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). | ||
|
|
||
| Adapted from verl PR #5230 tests. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from areal.utils.vision_dp import ( | ||
| assign_images_to_dp_ranks, | ||
| get_image_embedding_counts, | ||
| get_image_patch_counts, | ||
| prepare_local_vision_inputs, | ||
| ) | ||
|
|
||
|
|
||
| class TestGetImagePatchCounts: | ||
| @pytest.mark.parametrize( | ||
| "grid_thw,expected", | ||
| [ | ||
| ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]), | ||
| ([[1, 4, 4]], [16]), | ||
| ([[4, 4, 4]], [64]), | ||
| ], | ||
| ids=["multi-image", "single-image", "video-frames"], | ||
| ) | ||
| def test_patch_counts(self, grid_thw, expected): | ||
| counts = get_image_patch_counts(torch.tensor(grid_thw)) | ||
| assert counts == expected | ||
|
|
||
| def test_empty_input(self): | ||
| counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) | ||
| assert counts == [] | ||
|
|
||
|
|
||
| class TestGetImageEmbeddingCounts: | ||
| @pytest.mark.parametrize( | ||
| "grid_thw,merge_size,expected", | ||
| [ | ||
| ([[1, 8, 8]], 1, [64]), | ||
| ([[1, 8, 8]], 2, [16]), | ||
| ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]), | ||
| ], | ||
| ids=["no-merge", "merge-2", "multi-image-merge"], | ||
| ) | ||
| def test_embedding_counts(self, grid_thw, merge_size, expected): | ||
| counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size) | ||
| assert counts == expected | ||
|
|
||
|
|
||
| class TestAssignImagesToDpRanks: | ||
| @pytest.mark.parametrize( | ||
| "patch_counts,dp_size,expected_lens,expected_loads", | ||
| [ | ||
| ([100, 100, 100, 100], 2, [2, 2], [200, 200]), | ||
rchardx marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ([100, 200, 300], 1, [3], [600]), | ||
| ([100, 100, 100, 100, 100, 100], 3, [2, 2, 2], [200, 200, 200]), | ||
| ], | ||
| ids=["balanced-2ranks", "single-rank", "balanced-3ranks"], | ||
| ) | ||
| def test_balanced(self, patch_counts, dp_size, expected_lens, expected_loads): | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size) | ||
| assert [len(a) for a in assignments] == expected_lens | ||
| assert loads == expected_loads | ||
|
|
||
| def test_fewer_images_than_ranks(self): | ||
| assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4) | ||
| non_empty = sum(1 for a in assignments if len(a) > 0) | ||
| assert non_empty == 2 | ||
| all_assigned = set() | ||
| for a in assignments: | ||
| all_assigned.update(a) | ||
| assert all_assigned == {0, 1} | ||
|
|
||
| def test_empty_input(self): | ||
| assignments, loads = assign_images_to_dp_ranks([], dp_size=4) | ||
| assert all(len(a) == 0 for a in assignments) | ||
| assert all(load == 0 for load in loads) | ||
|
|
||
| def test_image_order_preserved(self): | ||
| assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2) | ||
| for rank_assignment in assignments: | ||
| assert rank_assignment == sorted(rank_assignment) | ||
|
|
||
|
|
||
| class TestPrepareLocalVisionInputs: | ||
| def test_basic_extraction(self): | ||
| pixel_values = torch.randn(100, 768) | ||
| grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]]) # 36 + 64 = 100 | ||
| image_assignments = [[0], [1]] | ||
|
|
||
| # Rank 0 | ||
| pix, grid, indices = prepare_local_vision_inputs( | ||
| pixel_values, grid_thw, image_assignments, dp_rank=0 | ||
| ) | ||
| assert pix.shape[0] == 36 | ||
| assert grid.shape[0] == 1 | ||
| assert indices == [0] | ||
| assert torch.allclose(pix, pixel_values[:36]) | ||
|
|
||
| # Rank 1 | ||
| pix, grid, indices = prepare_local_vision_inputs( | ||
| pixel_values, grid_thw, image_assignments, dp_rank=1 | ||
| ) | ||
| assert pix.shape[0] == 64 | ||
| assert indices == [1] | ||
| assert torch.allclose(pix, pixel_values[36:100]) | ||
|
|
||
| def test_multiple_images_per_rank(self): | ||
| pixel_values = torch.randn(200, 768) | ||
| grid_thw = torch.tensor([[1, 5, 10]] * 4) # 4 x 50 patches | ||
| image_assignments = [[0, 2], [1, 3]] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs( | ||
| pixel_values, grid_thw, image_assignments, dp_rank=0 | ||
| ) | ||
| assert pix.shape[0] == 100 | ||
| assert grid.shape[0] == 2 | ||
| assert indices == [0, 2] | ||
| expected = torch.cat([pixel_values[0:50], pixel_values[100:150]], dim=0) | ||
| assert torch.allclose(pix, expected) | ||
|
|
||
| def test_empty_rank(self): | ||
| pixel_values = torch.randn(100, 768) | ||
| grid_thw = torch.tensor([[1, 10, 10]]) | ||
| image_assignments = [[0], []] | ||
|
|
||
| pix, grid, indices = prepare_local_vision_inputs( | ||
| pixel_values, grid_thw, image_assignments, dp_rank=1 | ||
| ) | ||
| assert pix.shape[0] == 0 | ||
| assert grid.shape[0] == 0 | ||
| assert indices == [] | ||
|
|
||
| def test_grid_thw_preserved(self): | ||
| pixel_values = torch.randn(150, 768) | ||
| grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75 | ||
| image_assignments = [[0, 2], [1]] | ||
|
|
||
| _, local_grid, _ = prepare_local_vision_inputs( | ||
| pixel_values, grid_thw, image_assignments, dp_rank=0 | ||
| ) | ||
| assert local_grid.shape == (2, 3) | ||
| assert torch.equal(local_grid[0], grid_thw[0]) | ||
| assert torch.equal(local_grid[1], grid_thw[2]) | ||
|
|
||
|
|
||
| class TestIntegration: | ||
| def test_full_workflow(self): | ||
| grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]]) | ||
| total_patches = 16 + 64 + 16 + 36 + 16 # 148 | ||
| pixel_values = torch.randn(total_patches, 768) | ||
|
|
||
| patch_counts = get_image_patch_counts(grid_thw) | ||
| assert patch_counts == [16, 64, 16, 36, 16] | ||
|
|
||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) | ||
| all_assigned = [] | ||
| for a in assignments: | ||
| all_assigned.extend(a) | ||
| assert sorted(all_assigned) == [0, 1, 2, 3, 4] | ||
|
|
||
| total_local_patches = 0 | ||
| for rank in range(2): | ||
| pix, grid, indices = prepare_local_vision_inputs( | ||
| pixel_values, grid_thw, assignments, dp_rank=rank | ||
| ) | ||
| expected = sum(patch_counts[i] for i in indices) | ||
| assert pix.shape[0] == expected | ||
| assert grid.shape[0] == len(indices) | ||
| total_local_patches += pix.shape[0] | ||
|
|
||
| assert total_local_patches == total_patches | ||
|
|
||
| def test_same_size_images_4_ranks(self): | ||
| num_images = 50 | ||
| grid_thw = torch.tensor([[1, 8, 8]] * num_images) | ||
| patch_counts = get_image_patch_counts(grid_thw) | ||
| assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) | ||
|
|
||
| for rank in range(4): | ||
| assert 12 <= len(assignments[rank]) <= 13 | ||
| for load in loads: | ||
| assert load in [768, 832] | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.