diff --git a/tasks/omni/train_qwen_vl.py b/tasks/omni/train_qwen_vl.py index 0544651a5..62e36de60 100644 --- a/tasks/omni/train_qwen_vl.py +++ b/tasks/omni/train_qwen_vl.py @@ -129,6 +129,7 @@ def main(): attn_implementation=args.model.attn_implementation, encoder_data_balance=args.model.encoder_data_balance, encoder_data_balance_sorting_algo=args.model.encoder_data_balance_sorting_algo, + vision_dp=args.train.vision_dp, ) model_config = model.config helper.print_device_mem_info("VRAM usage after building model") diff --git a/tasks/train_torch.py b/tasks/train_torch.py index e6b539eb3..b853c42e6 100644 --- a/tasks/train_torch.py +++ b/tasks/train_torch.py @@ -150,6 +150,7 @@ def main(): attn_implementation=args.model.attn_implementation, moe_implementation=args.model.moe_implementation, init_device=args.train.init_device, + vision_dp=args.train.vision_dp, ) model_config = model.config helper.print_device_mem_info("VRAM usage after building model") diff --git a/tests/parallel/conftest.py b/tests/parallel/conftest.py new file mode 100644 index 000000000..07de5a574 --- /dev/null +++ b/tests/parallel/conftest.py @@ -0,0 +1,95 @@ +"""Stub heavy dependencies to allow CPU-only testing of vision_dp utilities. + +The veomni import chain pulls in datasets, flash_attn, CUDA ops, etc. +We stub everything except the specific module under test (vision_dp.py) +so that pytest can collect and run the tests on a CPU-only machine. +""" + +import importlib +import sys +import types + + +def _ensure_stub(name, **attrs): + """Create a stub module with __path__ if it doesn't already exist.""" + if name in sys.modules: + mod = sys.modules[name] + else: + mod = types.ModuleType(name) + mod.__path__ = [name.replace(".", "/")] + sys.modules[name] = mod + for k, v in attrs.items(): + setattr(mod, k, v) + return mod + + +# ── Stub veomni top-level (prevent __init__.py from importing ops/data) ── +_ensure_stub("veomni") + +_ensure_stub("veomni.ops") +_ensure_stub("veomni.data") +_ensure_stub("veomni.data.constants", IGNORE_INDEX=-100) + +# ── Stub utils ── +_ensure_stub("veomni.utils") + + +class _FakeLogger: + def __getattr__(self, name): + return lambda *a, **kw: None + + +_ensure_stub("veomni.utils.logging", get_logger=lambda name=None: _FakeLogger()) +_ensure_stub( + "veomni.utils.device", + get_device_type=lambda: "cpu", + get_device_id=lambda: "cpu", + IS_NPU_AVAILABLE=False, + IS_CUDA_AVAILABLE=False, +) +_ensure_stub("veomni.utils.import_utils", is_torch_version_greater_than=lambda v: True) + +# ── Stub distributed ── +_ensure_stub("veomni.distributed") + + +class _FakeParallelState: + sp_enabled = False + sp_size = 1 + sp_rank = 0 + sp_group = None + + +_ensure_stub( + "veomni.distributed.parallel_state", + get_parallel_state=lambda: _FakeParallelState(), + ParallelState=_FakeParallelState, +) + +# ── Stub the sequence_parallel __init__ and its heavy sub-modules ── +# We need to prevent the real __init__.py from running (it imports +# async_ulysses, comm, data, loss, ulysses, utils which have heavy deps). +# So we register the package stub FIRST, then load vision_dp.py directly. +_sp_pkg = _ensure_stub("veomni.distributed.sequence_parallel") + +for _sub in [ + "veomni.distributed.sequence_parallel.async_ulysses", + "veomni.distributed.sequence_parallel.comm", + "veomni.distributed.sequence_parallel.data", + "veomni.distributed.sequence_parallel.loss", + "veomni.distributed.sequence_parallel.ulysses", + "veomni.distributed.sequence_parallel.utils", +]: + _ensure_stub(_sub) + +# Now load vision_dp.py for real (it only depends on torch, dist, parallel_state) +_vision_dp_spec = importlib.util.spec_from_file_location( + "veomni.distributed.sequence_parallel.vision_dp", + "veomni/distributed/sequence_parallel/vision_dp.py", +) +_vision_dp_mod = importlib.util.module_from_spec(_vision_dp_spec) +sys.modules["veomni.distributed.sequence_parallel.vision_dp"] = _vision_dp_mod +_vision_dp_spec.loader.exec_module(_vision_dp_mod) + +# Attach it to the parent package +_sp_pkg.vision_dp = _vision_dp_mod diff --git a/tests/parallel/test_vision_dp.py b/tests/parallel/test_vision_dp.py new file mode 100644 index 000000000..2c8f75e8c --- /dev/null +++ b/tests/parallel/test_vision_dp.py @@ -0,0 +1,205 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). +""" + +import pytest +import torch + +from veomni.distributed.sequence_parallel.vision_dp import ( + assign_images_to_dp_ranks, + gather_vision_embeddings, + get_image_embedding_counts, + get_image_patch_counts, + prepare_local_vision_inputs, +) + + +class TestGetImagePatchCounts: + @pytest.mark.parametrize( + "grid_thw,expected", + [ + ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]), + ([[1, 4, 4]], [16]), + ([[4, 4, 4]], [64]), + ], + ids=["multi-image", "single-image", "video-frames"], + ) + def test_patch_counts_various_grids_correct_products(self, grid_thw, expected): + counts = get_image_patch_counts(torch.tensor(grid_thw)) + assert counts == expected + + def test_patch_counts_empty_input_returns_empty_list(self): + counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) + assert counts == [] + + +class TestGetImageEmbeddingCounts: + @pytest.mark.parametrize( + "grid_thw,merge_size,expected", + [ + ([[1, 8, 8]], 1, [64]), + ([[1, 8, 8]], 2, [16]), + ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]), + ], + ids=["no-merge", "merge-2", "multi-image-merge"], + ) + def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, expected): + counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size) + assert counts == expected + + +class TestAssignImagesToDpRanks: + @pytest.mark.parametrize( + "patch_counts,dp_size", + [ + ([100, 100, 100, 100], 2), + ([100, 200, 300], 1), + ([100, 100, 100, 100, 100, 100], 3), + ], + ids=["balanced-2ranks", "single-rank", "balanced-3ranks"], + ) + def test_assign_all_images_distributed(self, patch_counts, dp_size): + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size) + all_assigned = [] + for a in assignments: + all_assigned.extend(a) + assert sorted(all_assigned) == list(range(len(patch_counts))) + assert sum(loads) == sum(patch_counts) + + def test_assign_fewer_images_than_ranks_all_assigned(self): + assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4) + non_empty = sum(1 for a in assignments if len(a) > 0) + assert non_empty == 2 + all_assigned = set() + for a in assignments: + all_assigned.update(a) + assert all_assigned == {0, 1} + + def test_assign_empty_input_returns_empty(self): + assignments, loads = assign_images_to_dp_ranks([], dp_size=4) + assert all(len(a) == 0 for a in assignments) + assert all(load == 0 for load in loads) + + def test_assign_image_order_preserved_contiguous(self): + assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2) + for rank_assignment in assignments: + assert rank_assignment == sorted(rank_assignment) + + def test_assign_load_balanced_unequal_patches(self): + """With unequal patch counts, greedy balancing should reduce imbalance.""" + patch_counts = [4096, 256, 256, 256] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) + all_assigned = [] + for a in assignments: + all_assigned.extend(a) + assert sorted(all_assigned) == [0, 1, 2, 3] + max_load = max(loads) + min_load = min(load for load in loads if load > 0) + assert max_load / min_load < 8.0 + + +class TestPrepareLocalVisionInputs: + def test_prepare_two_images_splits_correctly(self): + pixel_values = torch.randn(100, 768) + grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]]) # 36 + 64 = 100 + image_assignments = [[0], [1]] + + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) + assert pix.shape[0] == 36 + assert grid.shape[0] == 1 + assert indices == [0] + assert torch.allclose(pix, pixel_values[:36]) + + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) + assert pix.shape[0] == 64 + assert grid.shape[0] == 1 + assert indices == [1] + assert torch.allclose(pix, pixel_values[36:100]) + + def test_prepare_multiple_contiguous_images_per_rank(self): + pixel_values = torch.randn(200, 768) + grid_thw = torch.tensor([[1, 5, 10]] * 4) # 4 x 50 patches + image_assignments = [[0, 1], [2, 3]] + + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) + assert pix.shape[0] == 100 + assert grid.shape[0] == 2 + assert indices == [0, 1] + assert torch.allclose(pix, pixel_values[:100]) + + def test_prepare_empty_rank_returns_empty(self): + pixel_values = torch.randn(100, 768) + grid_thw = torch.tensor([[1, 10, 10]]) + image_assignments = [[0], []] + + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) + assert pix.shape[0] == 0 + assert grid.shape[0] == 0 + assert indices == [] + + def test_prepare_grid_thw_preserved(self): + pixel_values = torch.randn(150, 768) + grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75 + image_assignments = [[0, 1], [2]] + + _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) + assert local_grid.shape == (2, 3) + assert torch.equal(local_grid[0], grid_thw[0]) + assert torch.equal(local_grid[1], grid_thw[1]) + + +class TestGatherVisionEmbeddings: + def test_gather_none_group_returns_input(self): + embeddings = torch.randn(10, 64) + result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10]) + assert torch.equal(result, embeddings) + + +class TestIntegration: + def test_full_workflow_all_patches_covered(self): + grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]]) + total_patches = 16 + 64 + 16 + 36 + 16 # 148 + pixel_values = torch.randn(total_patches, 768) + + patch_counts = get_image_patch_counts(grid_thw) + assert patch_counts == [16, 64, 16, 36, 16] + + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) + all_assigned = [] + for a in assignments: + all_assigned.extend(a) + assert sorted(all_assigned) == [0, 1, 2, 3, 4] + + total_local_patches = 0 + for rank in range(2): + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=rank) + expected = sum(patch_counts[i] for i in indices) + assert pix.shape[0] == expected + assert grid.shape[0] == len(indices) + total_local_patches += pix.shape[0] + + assert total_local_patches == total_patches + + def test_same_size_images_4_ranks_balanced(self): + num_images = 50 + grid_thw = torch.tensor([[1, 8, 8]] * num_images) + patch_counts = get_image_patch_counts(grid_thw) + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) + + for rank in range(4): + assert 12 <= len(assignments[rank]) <= 13 + for load in loads: + assert load in [768, 832] diff --git a/veomni/arguments/arguments_types.py b/veomni/arguments/arguments_types.py index b41ab3a03..6e865983a 100644 --- a/veomni/arguments/arguments_types.py +++ b/veomni/arguments/arguments_types.py @@ -554,6 +554,10 @@ class TrainingArguments: default=1, metadata={"help": "Ulysses sequence parallel size."}, ) + vision_dp: bool = field( + default=False, + metadata={"help": "Enable Vision DP: distribute ViT across Ulysses SP ranks."}, + ) context_parallel_size: int = field( default=1, metadata={"help": "Ring-attn context parallel size."}, diff --git a/veomni/distributed/sequence_parallel/__init__.py b/veomni/distributed/sequence_parallel/__init__.py index b2706be64..c9449f0eb 100644 --- a/veomni/distributed/sequence_parallel/__init__.py +++ b/veomni/distributed/sequence_parallel/__init__.py @@ -52,6 +52,14 @@ gather_seq_scatter_heads, ) from .utils import pad_tensor, unpad_tensor, vlm_images_a2a_meta +from .vision_dp import ( + assign_images_to_dp_ranks, + create_dp_vision_forward, + gather_vision_embeddings, + get_image_embedding_counts, + get_image_patch_counts, + prepare_local_vision_inputs, +) __all__ = [ @@ -88,4 +96,10 @@ "async_ulysses_output_projection", "divide_qkv_linear_weight", "divide_qkv_linear_bias", + "get_image_patch_counts", + "get_image_embedding_counts", + "assign_images_to_dp_ranks", + "prepare_local_vision_inputs", + "gather_vision_embeddings", + "create_dp_vision_forward", ] diff --git a/veomni/distributed/sequence_parallel/vision_dp.py b/veomni/distributed/sequence_parallel/vision_dp.py new file mode 100644 index 000000000..24fe87da1 --- /dev/null +++ b/veomni/distributed/sequence_parallel/vision_dp.py @@ -0,0 +1,335 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Vision Data Parallel utilities for VeOmni. + +Distribute whole images across SP ranks, not patches within images. +Each rank runs ViT on its assigned images, then all-gather combines embeddings. +Backward all_reduce(SUM) recovers complete gradients before slicing by assignment. +""" + +import torch +import torch.distributed as dist +from torch.autograd import Function + +from ...distributed.parallel_state import get_parallel_state +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]: + """Return [t*h*w for each image] from a [num_images, 3] grid_thw tensor.""" + if grid_thw.numel() == 0: + return [] + return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() + + +def get_image_embedding_counts(grid_thw: torch.Tensor, spatial_merge_size: int = 1) -> list[int]: + """Return per-image embedding counts after spatial merging: t * (h/merge) * (w/merge).""" + if grid_thw.numel() == 0: + return [] + if spatial_merge_size == 1: + return get_image_patch_counts(grid_thw) + t = grid_thw[:, 0] + h = grid_thw[:, 1] // spatial_merge_size + w = grid_thw[:, 2] // spatial_merge_size + return (t * h * w).tolist() + + +def assign_images_to_dp_ranks( + patch_counts: list[int], + dp_size: int, +) -> tuple[list[list[int]], list[int]]: + """Assign whole images to DP ranks via greedy contiguous bin-packing. + + Returns (image_assignments, rank_patch_counts). Images are kept contiguous + so the gather result needs no reordering. + """ + num_images = len(patch_counts) + if num_images == 0: + return [[] for _ in range(dp_size)], [0] * dp_size + + image_assignments: list[list[int]] = [[] for _ in range(dp_size)] + rank_loads = [0] * dp_size + + remaining_patches = sum(patch_counts) + img_idx = 0 + for rank in range(dp_size): + remaining_ranks = dp_size - rank + remaining_images = num_images - img_idx + + if remaining_images <= 0: + break + + # Dynamic target: distribute remaining patches evenly among remaining ranks + target = remaining_patches / remaining_ranks + + # Must leave at least 1 image for each remaining rank + max_images = remaining_images - (remaining_ranks - 1) + + # Greedily add images until we reach the target load or hit the max + count = 0 + while img_idx < num_images and count < max_images: + image_assignments[rank].append(img_idx) + rank_loads[rank] += patch_counts[img_idx] + img_idx += 1 + count += 1 + + # Stop early once we've reached the target (always take at least 1) + if rank_loads[rank] >= target: + break + + remaining_patches -= rank_loads[rank] + + return image_assignments, rank_loads + + +def prepare_local_vision_inputs( + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + image_assignments: list[list[int]], + dp_rank: int, + patch_counts: list[int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Extract pixel values and grid_thw for this DP rank's assigned images. + + Exploits contiguous assignment: a single slice instead of per-image cat. + """ + local_indices = image_assignments[dp_rank] + + if len(local_indices) == 0: + return ( + torch.empty( + (0, pixel_values.shape[1]) if pixel_values.dim() > 1 else (0,), + dtype=pixel_values.dtype, + device=pixel_values.device, + ), + torch.empty((0, 3), dtype=grid_thw.dtype, device=grid_thw.device), + [], + ) + + # local_indices are contiguous (e.g. [2, 3, 4]), so use tensor slicing + first_img_idx = local_indices[0] + last_img_idx = local_indices[-1] + + # Use pre-computed patch_counts to avoid redundant GPU→CPU transfer + if patch_counts is None: + patch_counts = get_image_patch_counts(grid_thw) + patch_counts_tensor = torch.tensor(patch_counts, device=grid_thw.device, dtype=torch.long) + offsets = torch.cat( + ( + torch.tensor([0], device=grid_thw.device, dtype=torch.long), + torch.cumsum(patch_counts_tensor, dim=0), + ) + ) + + start_patch = offsets[first_img_idx].item() + end_patch = offsets[last_img_idx + 1].item() + + local_pixel_values = pixel_values[start_patch:end_patch] + local_grid_thw = grid_thw[first_img_idx : last_img_idx + 1] + + expected_patches = end_patch - start_patch + assert local_pixel_values.shape[0] == expected_patches, ( + f"[Vision DP] Local patch count mismatch: " + f"extracted={local_pixel_values.shape[0]}, expected={expected_patches}, " + f"local_indices={local_indices}" + ) + + return local_pixel_values, local_grid_thw, local_indices + + +class GatherVisionEmbeddings(Function): + """All-gather vision embeddings with gradient support. + + Contiguous assignment means simple concat without reordering. + Backward: all_reduce(SUM) to aggregate gradients from all sequence shards, + then slice to extract this rank's image gradients. + """ + + @staticmethod + def forward(ctx, local_embeddings, dp_group, all_counts: list[int]): + dp_size = dist.get_world_size(dp_group) + dp_rank = dist.get_rank(dp_group) + ctx.dp_size = dp_size + ctx.dp_group = dp_group + ctx.all_counts = all_counts + ctx.dp_rank = dp_rank + + if dp_size == 1: + return local_embeddings + + max_count = max(all_counts) if all_counts else 0 + if max_count == 0: + return local_embeddings + + hidden_size = local_embeddings.shape[1] if local_embeddings.dim() > 1 else 1 + + if local_embeddings.shape[0] < max_count: + pad_size = max_count - local_embeddings.shape[0] + padding = torch.zeros( + (pad_size, hidden_size), + dtype=local_embeddings.dtype, + device=local_embeddings.device, + ) + local_padded = torch.cat([local_embeddings, padding], dim=0) + else: + local_padded = local_embeddings + + gathered = [torch.empty_like(local_padded) for _ in range(dp_size)] + dist.all_gather(gathered, local_padded, group=dp_group) + + result_chunks = [gathered[r][: all_counts[r]] for r in range(dp_size)] + result = torch.cat(result_chunks, dim=0) + return result + + @staticmethod + def backward(ctx, grad_output): + dp_size = ctx.dp_size + + if dp_size == 1: + return grad_output, None, None + + all_counts = ctx.all_counts + dp_rank = ctx.dp_rank + dp_group = ctx.dp_group + + # all_reduce(SUM) aggregates partial gradients from all SP ranks: + # each rank only has non-zero grad for vision tokens in its sequence shard. + # NCCL all_reduce requires contiguous tensors — defensive guard. + grad = grad_output.contiguous() + dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dp_group) + + start = sum(all_counts[:dp_rank]) + end = start + all_counts[dp_rank] + local_grad = grad[start:end] + return local_grad, None, None + + +def gather_vision_embeddings(local_embeddings, dp_group, all_counts: list[int]): + """All-gather vision embeddings from all DP ranks with gradient support.""" + if dp_group is None or dist.get_world_size(dp_group) == 1: + return local_embeddings + return GatherVisionEmbeddings.apply(local_embeddings, dp_group, all_counts) + + +def create_dp_vision_forward(original_forward): + """Wrap VisionTransformer.forward for Vision DP (Data Parallel across SP ranks). + + Strategy: + 1. Distribute whole images to SP ranks (not patches within images) + 2. Each rank processes its assigned images independently + 3. All-gather embeddings at the end (contiguous assignment, no reordering) + + Passes _vision_dp=True so the inner ViT can skip its own patch-level SP logic. + + Gradient correctness: after all-gather in forward, each SP rank's inputs_embeds + contains vision tokens from ALL images. But Ulysses gives each rank only its + sequence shard. In backward, each rank only has non-zero gradient for vision + tokens in its own shard. The all_reduce(SUM) in GatherVisionEmbeddings.backward + aggregates partial gradients from all ranks, recovering the complete gradient. + """ + + def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): + ps = get_parallel_state() + dp_size = ps.sp_size if ps.sp_enabled else 1 + if dp_size <= 1: + return original_forward(self, hidden_states, grid_thw, _vision_dp=True, **kwargs) + + dp_group = ps.sp_group + dp_rank = ps.sp_rank + + # Move grid_thw to CPU once to avoid repeated GPU->CPU syncs + grid_thw_cpu = grid_thw.cpu() + + # Step 1: Get image assignment + patch_counts = get_image_patch_counts(grid_thw_cpu) + total_patches = sum(patch_counts) + assert hidden_states.shape[0] == total_patches + + spatial_merge_size = 1 + if hasattr(self, "merger") and hasattr(self.merger, "spatial_merge_size"): + spatial_merge_size = self.merger.spatial_merge_size + elif hasattr(self, "spatial_merge_size"): + spatial_merge_size = self.spatial_merge_size + + embedding_counts = get_image_embedding_counts(grid_thw_cpu, spatial_merge_size) + total_embeddings = sum(embedding_counts) + + image_assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size) + + # Step 2: Extract local inputs (pass pre-computed patch_counts to avoid GPU→CPU sync) + local_pixels, local_grid_thw, local_indices = prepare_local_vision_inputs( + hidden_states, grid_thw, image_assignments, dp_rank, patch_counts=patch_counts + ) + + # Detect Qwen3-VL deepstack: model attribute, not return type, + # because empty ranks don't call original_forward and can't inspect the return. + has_deepstack = hasattr(self, "deepstack_merger_list") + + # Step 3: Process local images (pass _vision_dp=True to skip SP patches) + if local_pixels.shape[0] > 0: + local_embeddings = original_forward(self, local_pixels, local_grid_thw, _vision_dp=True, **kwargs) + else: + # This rank has no images, create empty tensor with correct hidden size + hidden_size = getattr(getattr(self, "config", None), "out_hidden_size", None) + if hidden_size is None: + raise RuntimeError( + f"Cannot determine hidden_size: self.config.out_hidden_size not found. " + f"Model type: {type(self).__name__}" + ) + + local_embeddings = torch.empty( + (0, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + # Empty rank must participate in autograd for backward all_reduce + local_embeddings.requires_grad_() + + # Unpack Qwen3-VL deepstack: forward returns (embeddings, list[3 × Tensor]) + local_deepstack = None + if has_deepstack: + if isinstance(local_embeddings, tuple): + local_embeddings, local_deepstack = local_embeddings[0], local_embeddings[1] + else: + # Empty rank: create matching empty deepstack tensors + num_deepstack = len(self.deepstack_merger_list) + h = local_embeddings.shape[1] + local_deepstack = [ + torch.empty( + (0, h), dtype=hidden_states.dtype, device=hidden_states.device + ) + for _ in range(num_deepstack) + ] + + # Step 4: All-gather + # Compute per-rank embedding counts locally (grid_thw is replicated on all ranks) + all_counts = [sum(embedding_counts[i] for i in image_assignments[r]) for r in range(dp_size)] + all_embeddings = gather_vision_embeddings(local_embeddings, dp_group, all_counts) + assert all_embeddings.shape[0] == total_embeddings + + # Step 5: All-gather deepstack embeddings (all ranks must participate) + if local_deepstack is not None: + gathered_deepstack = [ + gather_vision_embeddings(ds, dp_group, all_counts) + for ds in local_deepstack + ] + return all_embeddings, gathered_deepstack + + return all_embeddings + + return dp_vision_forward diff --git a/veomni/models/auto.py b/veomni/models/auto.py index 599d3c52e..01197efde 100644 --- a/veomni/models/auto.py +++ b/veomni/models/auto.py @@ -80,6 +80,7 @@ def build_foundation_model( config_kwargs: Optional[Dict[str, Any]] = None, encoder_data_balance: Optional[bool] = False, encoder_data_balance_sorting_algo: Optional[str] = "post_mbs_balancing_greedy_without_pad", + vision_dp: bool = False, ) -> "PreTrainedModel": """ Builds the foundation model. @@ -186,6 +187,20 @@ def wrapped_forward(*args, **kwargs): "to avoid unexpected kernel loading side effects." ) + if vision_dp and get_parallel_state().sp_enabled: + if config.model_type == "qwen2_5_vl": + from .transformers.qwen2_5vl.modeling_qwen2_5_vl import apply_vision_dp_patch_qwen25 + + apply_vision_dp_patch_qwen25() + logger.info_rank0("Applied Vision DP patch for Qwen2.5-VL") + elif config.model_type == "qwen3_vl": + from .transformers.qwen3_vl.modeling_qwen3_vl import apply_vision_dp_patch + + apply_vision_dp_patch() + logger.info_rank0("Applied Vision DP patch for Qwen3-VL") + else: + logger.warning_rank0(f"Vision DP requested but model_type={config.model_type} is not supported") + model_class_path = f"{model.__class__.__module__}.{model.__class__.__name__}" logger.info_rank0(f"Built foundation model class: {model_class_path}") diff --git a/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py b/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py index c2a843f76..6b827c340 100644 --- a/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py +++ b/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py @@ -149,6 +149,7 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( self: Qwen2_5_VisionTransformerPretrainedModel, hidden_states: torch.Tensor, grid_thw: torch.Tensor, + _vision_dp: bool = False, **kwargs, ) -> torch.Tensor: """ @@ -157,10 +158,16 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + _vision_dp: When True, skip patch-level SP logic (Vision DP handles + distribution at the image level instead). Returns: `torch.Tensor`: hidden_states. """ + # When _vision_dp=True, this forward receives only this rank's assigned + # images (complete, not split). Skip all SP pad/slice/all-to-all operations. + use_sp = get_parallel_state().sp_enabled and not _vision_dp + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) window_index, cu_window_seqlens = self.get_window_index(grid_thw) @@ -183,7 +190,7 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( # --- Patch.1 --- unpadded_dim_size = cu_seqlens[-1] - if get_parallel_state().sp_enabled: + if use_sp: hidden_states = gather_seq_scatter_heads( hidden_states, seq_dim=0, head_dim=1, group=get_parallel_state().sp_group ) @@ -201,7 +208,7 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - if get_parallel_state().sp_enabled: + if use_sp: if sp_padding_size > 0: # --- Patch.1 --- hidden_states = pad_tensor(hidden_states, dim=0, padding_size=sp_padding_size) @@ -253,7 +260,7 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( reverse_indices = torch.argsort(window_index) # --- Patch.1 --- - if get_parallel_state().sp_enabled: + if use_sp: sp_padding_size = hidden_states.size(0) - unpadded_dim_size hidden_states = gather_seq_scatter_heads( hidden_states, seq_dim=0, head_dim=1, group=get_parallel_state().sp_group @@ -265,7 +272,7 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( hidden_states = hidden_states[reverse_indices, :] # --- Patch.1 --- - if get_parallel_state().sp_enabled: + if use_sp: if sp_padding_size > 0: hidden_states = pad_tensor(hidden_states, dim=0, padding_size=sp_padding_size) hidden_states = gather_heads_scatter_seq( @@ -283,7 +290,7 @@ def Qwen2_5_VisionTransformerPretrainedModel_forward( # ================================================================ # --- Patch.1 --- def Qwen2_5_VisionTransformerPretrainedModel_dummy_forward(self: Qwen2_5_VisionTransformerPretrainedModel): - if get_parallel_state().sp_enabled: + if get_parallel_state().sp_enabled and not getattr(self.__class__, "_vision_dp_patched", False): if getattr(self, "_sp_dummy_data", None) is None: sp_size = get_parallel_state().sp_size pixel_values = torch.randn((4, 3 * 2 * 14 * 14), dtype=self.dtype, device=self.device) @@ -426,10 +433,13 @@ def forward( ) # --- Patch.3 --- + _vision_dp_active = getattr(self.visual.__class__, "_vision_dp_patched", False) + if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) # --- Patch.3 --- - if get_parallel_state().sp_enabled: + # When Vision DP is active, image_embeds are already all-gathered + if get_parallel_state().sp_enabled and not _vision_dp_active: image_embeds = gather_seq_scatter_heads( image_embeds, seq_dim=0, head_dim=-1, group=get_parallel_state().sp_group ) @@ -448,7 +458,8 @@ def forward( if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) # --- Patch.3 --- - if get_parallel_state().sp_enabled: + # When Vision DP is active, video_embeds are already all-gathered + if get_parallel_state().sp_enabled and not _vision_dp_active: video_embeds = gather_seq_scatter_heads( video_embeds, seq_dim=0, head_dim=-1, group=get_parallel_state().sp_group ) @@ -692,7 +703,7 @@ def forward( ) -def apply_veomni_qwen25_vl_patch(): +def apply_veomni_qwen25_vl_patch(vision_dp: bool = False): logger.info_rank0("Apply VeOmni patch to Qwen2.5_VL.") hf_qwen25vl.Qwen2_5_VLVisionAttention.forward = Qwen2_5_VLVisionAttention_forward hf_qwen25vl.Qwen2_5_VisionTransformerPretrainedModel.forward = Qwen2_5_VisionTransformerPretrainedModel_forward @@ -701,3 +712,23 @@ def apply_veomni_qwen25_vl_patch(): ) hf_qwen25vl.Qwen2_5_VLModel = Qwen2_5_VLModel hf_qwen25vl.Qwen2_5_VLForConditionalGeneration = Qwen2_5_VLForConditionalGeneration + + if vision_dp: + apply_vision_dp_patch_qwen25() + + +def apply_vision_dp_patch_qwen25(): + """Apply Vision DP to Qwen2.5-VL VisionTransformer. + + Eliminates 4 all-to-all operations in the ViT forward by distributing + whole images instead of splitting patches across SP ranks. + """ + from ...distributed.sequence_parallel.vision_dp import create_dp_vision_forward + + cls = hf_qwen25vl.Qwen2_5_VisionTransformerPretrainedModel + if getattr(cls, "_vision_dp_patched", False): + return + original = cls.forward + cls.forward = create_dp_vision_forward(original) + cls._vision_dp_patched = True + logger.info_rank0("Applied Vision DP patch to Qwen2_5_VisionTransformerPretrainedModel.forward") diff --git a/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py b/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py index f917938e0..29bdc4b4a 100644 --- a/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py +++ b/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py @@ -401,7 +401,11 @@ def fast_pos_embed_interpolate(self, grid_thw): # 6. move cu_seqlens to cpu when using NPU to avoid per layer CPU-GPU sync when using FA # ================================================================ def Qwen3VLVisionModel_forward( - self: Qwen3VLVisionModel, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs + self: Qwen3VLVisionModel, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + _vision_dp: bool = False, + **kwargs, ) -> torch.Tensor: """ Args: @@ -409,16 +413,22 @@ def Qwen3VLVisionModel_forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + _vision_dp: When True, skip patch-level SP logic (Vision DP handles + distribution at the image level instead). Returns: `torch.Tensor`: hidden_states. """ + # When _vision_dp=True, this forward receives only this rank's assigned + # images (complete, not split). Skip all SP pad/slice operations. + use_sp = get_parallel_state().sp_enabled and not _vision_dp + hidden_states = self.patch_embed(hidden_states) pos_embeds = self.fast_pos_embed_interpolate(grid_thw) # --- Patch.1 --- - if get_parallel_state().sp_enabled: + if use_sp: # We need to do padding here because of hidden_states did padding with pad_scale=4 pos_embeds = sp_pad_and_slice(pos_embeds, dim=0, pad_value=0, pad_scale=4) # --- Patch.1 --- @@ -448,7 +458,7 @@ def Qwen3VLVisionModel_forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - if get_parallel_state().sp_enabled: + if use_sp: # --- Patch.3 --- cos, sin = position_embeddings cos = sp_pad_and_slice(cos, dim=0, pad_value=0, pad_scale=4) @@ -502,7 +512,7 @@ def Qwen3VLVisionModel_forward( # ================================================================ # --- Patch.1 --- def Qwen3VLVisionModel_dummy_forward(self): - if get_parallel_state().sp_enabled: + if get_parallel_state().sp_enabled and not getattr(self.__class__, "_vision_dp_patched", False): sp_size = get_parallel_state().sp_size pixel_values = torch.zeros((16, 3 * 2 * 16 * 16), dtype=self.dtype, device=self.device) # If using SP, pixel_values is sliced but grid_thw is not @@ -643,11 +653,15 @@ def forward( # Initialize fake_deepstack to None fake_deepstack = None + _vision_dp_active = getattr(self.visual.__class__, "_vision_dp_patched", False) + if pixel_values is not None: image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) # --- Patch.3 --- - if get_parallel_state().sp_enabled: + # When Vision DP is active, image_embeds are already all-gathered + # by the dp_vision_forward wrapper — skip the gather here. + if get_parallel_state().sp_enabled and not _vision_dp_active: # (seq_len // sp_size, hidden_size) to (seq_len, hidden_size // sp_size) image_embeds = gather_seq_scatter_heads( image_embeds, seq_dim=0, head_dim=-1, group=get_parallel_state().sp_group @@ -697,8 +711,8 @@ def forward( if pixel_values_videos is not None: # --- Patch.3 --- video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - # sequence parallel patch for video embeds - if get_parallel_state().sp_enabled: + # When Vision DP is active, video_embeds are already all-gathered + if get_parallel_state().sp_enabled and not _vision_dp_active: # (seq_len // sp_size, hidden_size) to (seq_len, hidden_size // sp_size) video_embeds = gather_seq_scatter_heads( video_embeds, seq_dim=0, head_dim=-1, group=get_parallel_state().sp_group @@ -967,7 +981,7 @@ def forward( ) -def apply_veomni_qwen3vl_patch(): +def apply_veomni_qwen3vl_patch(vision_dp: bool = False): logger.info_rank0("Apply VeOmni patch to Qwen3_VL.") hf_qwen3vl.Qwen3VLVisionAttention.forward = Qwen3VLVisionAttention_forward hf_qwen3vl.Qwen3VLTextAttention = Qwen3VLTextAttention @@ -979,7 +993,30 @@ def apply_veomni_qwen3vl_patch(): hf_qwen3vl.Qwen3VLModel = Qwen3VLModel hf_qwen3vl.Qwen3VLForConditionalGeneration = Qwen3VLForConditionalGeneration + if vision_dp: + apply_vision_dp_patch() + if IS_NPU_AVAILABLE: from .npu_patch import apply_qwen3vl_npu_patch apply_qwen3vl_npu_patch() + + +def apply_vision_dp_patch(): + """Apply Vision DP monkey-patch to VisionModel.forward. + + Instead of patch-level SP (which requires all-to-all per ViT layer for + window attention models), Vision DP distributes *whole images* across SP + ranks and only needs a single all-gather after the ViT. + + Safe to call multiple times — each class is only patched once. + """ + from ...distributed.sequence_parallel.vision_dp import create_dp_vision_forward + + cls = hf_qwen3vl.Qwen3VLVisionModel + if getattr(cls, "_vision_dp_patched", False): + return + original = cls.forward + cls.forward = create_dp_vision_forward(original) + cls._vision_dp_patched = True + logger.info_rank0("Applied Vision DP patch to Qwen3VLVisionModel.forward")