Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions areal/models/transformers/ulyssess_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ def apply_monkey_patch(

patch_vlm_for_ulysses_input_slicing(model_class)
logger.info(f"Patched {model_class_name}.forward")

# Apply Vision DP: distribute ViT computation across SP ranks
from areal.utils.vision_dp import apply_vision_dp_patch

apply_vision_dp_patch()
else:
from transformers.integrations import flash_attention

Expand Down
185 changes: 185 additions & 0 deletions areal/tests/test_vision_dp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""
Unit tests for Vision Data Parallel utilities (CPU-only, no distributed).

Adapted from verl PR #5230 tests.
"""

import pytest
import torch

from areal.utils.vision_dp import (
assign_images_to_dp_ranks,
get_image_embedding_counts,
get_image_patch_counts,
prepare_local_vision_inputs,
)


class TestGetImagePatchCounts:
@pytest.mark.parametrize(
"grid_thw,expected",
[
([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]),
([[1, 4, 4]], [16]),
([[4, 4, 4]], [64]),
],
ids=["multi-image", "single-image", "video-frames"],
)
def test_patch_counts(self, grid_thw, expected):
counts = get_image_patch_counts(torch.tensor(grid_thw))
assert counts == expected

def test_empty_input(self):
counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long))
assert counts == []


class TestGetImageEmbeddingCounts:
@pytest.mark.parametrize(
"grid_thw,merge_size,expected",
[
([[1, 8, 8]], 1, [64]),
([[1, 8, 8]], 2, [16]),
([[1, 6, 6], [1, 4, 4]], 2, [9, 4]),
],
ids=["no-merge", "merge-2", "multi-image-merge"],
)
def test_embedding_counts(self, grid_thw, merge_size, expected):
counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size)
assert counts == expected


class TestAssignImagesToDpRanks:
@pytest.mark.parametrize(
"patch_counts,dp_size,expected_lens,expected_loads",
[
([100, 100, 100, 100], 2, [2, 2], [200, 200]),
([100, 200, 300], 1, [3], [600]),
([100, 100, 100, 100, 100, 100], 3, [2, 2, 2], [200, 200, 200]),
],
ids=["balanced-2ranks", "single-rank", "balanced-3ranks"],
)
def test_balanced(self, patch_counts, dp_size, expected_lens, expected_loads):
assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size)
assert [len(a) for a in assignments] == expected_lens
assert loads == expected_loads

def test_fewer_images_than_ranks(self):
assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4)
non_empty = sum(1 for a in assignments if len(a) > 0)
assert non_empty == 2
all_assigned = set()
for a in assignments:
all_assigned.update(a)
assert all_assigned == {0, 1}

def test_empty_input(self):
assignments, loads = assign_images_to_dp_ranks([], dp_size=4)
assert all(len(a) == 0 for a in assignments)
assert all(load == 0 for load in loads)

def test_image_order_preserved(self):
assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2)
for rank_assignment in assignments:
assert rank_assignment == sorted(rank_assignment)


class TestPrepareLocalVisionInputs:
def test_basic_extraction(self):
pixel_values = torch.randn(100, 768)
grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]]) # 36 + 64 = 100
image_assignments = [[0], [1]]

# Rank 0
pix, grid, indices = prepare_local_vision_inputs(
pixel_values, grid_thw, image_assignments, dp_rank=0
)
assert pix.shape[0] == 36
assert grid.shape[0] == 1
assert indices == [0]
assert torch.allclose(pix, pixel_values[:36])

# Rank 1
pix, grid, indices = prepare_local_vision_inputs(
pixel_values, grid_thw, image_assignments, dp_rank=1
)
assert pix.shape[0] == 64
assert indices == [1]
assert torch.allclose(pix, pixel_values[36:100])

def test_multiple_images_per_rank(self):
pixel_values = torch.randn(200, 768)
grid_thw = torch.tensor([[1, 5, 10]] * 4) # 4 x 50 patches
image_assignments = [[0, 2], [1, 3]]

pix, grid, indices = prepare_local_vision_inputs(
pixel_values, grid_thw, image_assignments, dp_rank=0
)
assert pix.shape[0] == 100
assert grid.shape[0] == 2
assert indices == [0, 2]
expected = torch.cat([pixel_values[0:50], pixel_values[100:150]], dim=0)
assert torch.allclose(pix, expected)

def test_empty_rank(self):
pixel_values = torch.randn(100, 768)
grid_thw = torch.tensor([[1, 10, 10]])
image_assignments = [[0], []]

pix, grid, indices = prepare_local_vision_inputs(
pixel_values, grid_thw, image_assignments, dp_rank=1
)
assert pix.shape[0] == 0
assert grid.shape[0] == 0
assert indices == []

def test_grid_thw_preserved(self):
pixel_values = torch.randn(150, 768)
grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75
image_assignments = [[0, 2], [1]]

_, local_grid, _ = prepare_local_vision_inputs(
pixel_values, grid_thw, image_assignments, dp_rank=0
)
assert local_grid.shape == (2, 3)
assert torch.equal(local_grid[0], grid_thw[0])
assert torch.equal(local_grid[1], grid_thw[2])


class TestIntegration:
def test_full_workflow(self):
grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]])
total_patches = 16 + 64 + 16 + 36 + 16 # 148
pixel_values = torch.randn(total_patches, 768)

patch_counts = get_image_patch_counts(grid_thw)
assert patch_counts == [16, 64, 16, 36, 16]

assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
all_assigned = []
for a in assignments:
all_assigned.extend(a)
assert sorted(all_assigned) == [0, 1, 2, 3, 4]

total_local_patches = 0
for rank in range(2):
pix, grid, indices = prepare_local_vision_inputs(
pixel_values, grid_thw, assignments, dp_rank=rank
)
expected = sum(patch_counts[i] for i in indices)
assert pix.shape[0] == expected
assert grid.shape[0] == len(indices)
total_local_patches += pix.shape[0]

assert total_local_patches == total_patches

def test_same_size_images_4_ranks(self):
num_images = 50
grid_thw = torch.tensor([[1, 8, 8]] * num_images)
patch_counts = get_image_patch_counts(grid_thw)
assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)

for rank in range(4):
assert 12 <= len(assignments[rank]) <= 13
for load in loads:
assert load in [768, 832]
Loading