
feat(models): shard vision encoder across Ulysses SP ranks#929

Merged
rchardx merged 14 commits into inclusionAI:main from aoshen524:feat/vision-dp on Mar 5, 2026
Conversation


@aoshen524 aoshen524 commented Feb 16, 2026

Summary

  • Vision DP distributes whole images across Ulysses SP ranks for ViT computation, then all-gathers embeddings. Reduces ViT memory from O(total_images) to O(total_images/N) when sp_size > 1.
  • Gated behind vision_dp: bool in TrainEngineConfig — opt-in, default behavior unchanged.
  • Supports Qwen2-VL, Qwen2.5-VL, Qwen3-VL (including Qwen3-VL deepstack).

Design

| Aspect | Detail |
| ------ | ------ |
| Distribution unit | Whole images (not patches) — preserves ViT cu_seqlens semantics |
| Assignment | Load-balanced contiguous: greedy bin-packing by patch count, preserving image order |
| Gather | GatherVisionEmbeddings autograd Function; embedding counts computed locally (no all_gather) |
| Gradient sync | all_reduce(SUM) across the SP group in backward, before slicing — recovers the complete gradient for each image from all sequence shards |
| Integration point | apply_monkey_patch() in ulyssess_patch.py |
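The load-balanced contiguous assignment described above can be sketched as a greedy partition over the ordered image list. The function name and the exact tie-breaking rule here are illustrative, not the PR's actual code:

```python
def assign_images_contiguously(patch_counts, num_ranks):
    """Greedy contiguous bin-packing: each rank takes the next images until
    its load reaches a dynamic target (remaining patches / remaining ranks).
    Returns per-rank (start, end) index pairs; order is preserved, and a
    trailing rank may end up empty when there are few images."""
    n = len(patch_counts)
    bounds, start, remaining = [], 0, sum(patch_counts)
    for rank in range(num_ranks):
        if rank == num_ranks - 1:
            bounds.append((start, n))  # last rank takes whatever remains
            break
        target = remaining / (num_ranks - rank)
        end, load = start, 0
        while end < n and load < target:
            load += patch_counts[end]
            end += 1
        bounds.append((start, end))
        start, remaining = end, remaining - load
    return bounds
```

Because each rank's images are a contiguous range, the all-gathered embeddings come back already in original image order, with no reordering step.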

Files Changed

| File | Change |
| ---- | ------ |
| areal/utils/vision_dp.py | Core Vision DP utilities (new) |
| areal/models/transformers/ulyssess_patch.py | Gate apply_vision_dp_patch() behind the vision_dp flag |
| areal/api/cli_args.py | Add vision_dp: bool to TrainEngineConfig |
| areal/engine/fsdp_engine.py | Pass vision_dp=self.config.vision_dp to apply_monkey_patch() |
| areal/tests/test_vision_dp.py | 21 CPU-only unit tests (new) |

Usage

--train_engine.vision_dp true --train_engine.ulysses_sp_size 2

Precision Alignment (verl reference experiment)

Validated in verl-project/verl#5230 under controlled conditions (same algorithm shared across frameworks):

| Scope | Params | max_diff | mean_diff | cosine_sim |
| ----- | ------ | -------- | --------- | ---------- |
| vision | 390 | 4.70e-05 | 2.93e-08 | 0.9991 |
| language | 338 | 9.50e-08 | 1.15e-10 | 1.0020 |
| other | 1 | 9.13e-08 | 2.25e-13 | 1.0001 |
  • Vision DP is numerically lossless: all differences within bf16 precision (~1e-05 max)
  • Language gradients are bitwise identical at the pre-clip phase
  • Root cause: all_reduce(SUM) changes FP accumulation order in vision backward

Adapted from verl PR #5230.

🤖 Generated with Claude Code

…ss Ulysses SP ranks

When using Ulysses SP (sp_size > 1), the ViT processes ALL images on every
SP rank redundantly. Vision DP distributes whole images across SP ranks and
all-gathers the embeddings, reducing ViT memory from O(total_images) to
O(total_images/N).

Key design:
- Image-level distribution (not patch-level) to preserve cu_seqlens semantics
- Contiguous assignment: rank k gets images [start_k, ..., end_k], no reorder
  needed after all-gather
- GatherVisionEmbeddings autograd Function with gradient scaling to compensate
  for partial processing before FSDP reduction
- No explicit gradient sync needed: AReaL's FSDP uses dp_sp submesh which
  covers both DP and SP dimensions

Adapted from verl PR #5230.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @aoshen524, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces Vision Data Parallelism (Vision DP) to reduce Vision Transformer (ViT) memory consumption in Visual Language Models (VLMs) operating with Ulysses Sequence Parallelism. By distributing images across SP ranks rather than processing all images on every rank, it scales ViT memory down by the SP world size. The integration is seamless for supported VLMs and leverages existing FSDP mechanisms for gradient synchronization.

Highlights

  • Memory Optimization: Implemented Vision Data Parallel (Vision DP) to distribute whole images across Ulysses Sequence Parallel (SP) ranks, reducing Vision Transformer (ViT) memory usage from O(total_images) to O(total_images/N) when sp_size > 1.
  • Automatic Activation: Vision DP is automatically activated for Visual Language Models (VLMs) such as Qwen2-VL, Qwen2.5-VL, and Qwen3-VL through the existing apply_monkey_patch() flow when sp_size > 1.
  • Gradient Synchronization: Eliminated the need for explicit gradient synchronization as AReaL's FSDP utilizes a dp_sp submesh, which inherently covers both Data Parallel (DP) and Sequence Parallel (SP) dimensions, handling gradient reduction automatically.
  • Image Distribution Strategy: Designed Vision DP to distribute whole images contiguously across ranks, preserving ViT cu_seqlens semantics and simplifying the all-gather process without requiring reordering.
  • Gradient Scaling: Incorporated a GatherVisionEmbeddings autograd Function that includes gradient scaling in the backward pass to correctly compensate for partial image processing on each rank.


Changelog
  • areal/models/transformers/ulyssess_patch.py
    • Integrated the apply_vision_dp_patch function to enable Vision DP for supported VLM models.
  • areal/tests/test_vision_dp.py
    • Added comprehensive CPU-only unit tests for the Vision Data Parallel utilities, covering image patch counting, embedding counting, image assignment to ranks, and local input preparation.
  • areal/utils/vision_dp.py
    • Added a new module containing core Vision Data Parallel utilities, including functions for calculating image patch and embedding counts, assigning images to distributed ranks, preparing local vision inputs, and a custom GatherVisionEmbeddings autograd function for distributed embedding collection with gradient scaling.
Activity
  • Initial implementation adapted from verl PR #5230.
  • Automated generation of parts of the code using Claude Code.
  • Ruff linting and formatting checks have passed.
  • Unit tests for areal/tests/test_vision_dp.py are pending execution in the CI container.
  • Multi-GPU correctness tests with VLM models are pending.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces Vision Data Parallelism (Vision DP) to distribute Vision Transformer (ViT) computation across sequence parallelism ranks, which is a great feature for reducing memory usage in Vision-Language Models. The implementation is well-structured, with the core logic encapsulated in areal/utils/vision_dp.py, integration via monkey-patching, and comprehensive unit tests. The code is clean and well-documented. I have one suggestion to improve the performance and readability of the patch extraction logic in prepare_local_vision_inputs by leveraging tensorized operations.

Collaborator

@rchardx rchardx left a comment


Thanks for the clean implementation! The design is well-documented and the contiguous assignment approach is elegant. I have a few questions and concerns I'd like to discuss — particularly around gradient flow in the backward pass. Please see inline comments.

1. Fix gradient routing bug: add all_reduce(SUM) in backward before
   slicing to aggregate gradient contributions from all SP ranks.
   Remove the incorrect `* dp_size` scaling and `grad_scaler` param.

2. Load balancing: replace count-based chunking with greedy contiguous
   bin-packing by patch load (dynamic target per rank).

3. Remove unnecessary all_gather: compute per-rank embedding counts
   locally since grid_thw is replicated on all ranks.

4. Add idempotency guard: extract _patch_vision_class() helper with
   _vision_dp_patched attribute to prevent double-wrapping.

5. Remove unreachable qwen3_vl_moe code block.

6. Move grid_thw to CPU at dp_vision_forward entry to avoid repeated
   GPU→CPU syncs in metadata helpers.

7. Replace Python loops with tensor slicing in prepare_local_vision_inputs
   (cumsum + contiguous slice).

8. Improve tests: rename to test_<what>_<condition>_<expected> convention,
   add load balance test, add gather_vision_embeddings passthrough test,
   update assignments to use contiguous indices.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
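The gradient-routing fix in item 1 can be checked numerically on a single process. This sketch simulates two SP ranks whose losses touch disjoint sequence shards, sums their partial gradients (the role of all_reduce(SUM) in the backward), and slices out one rank's own rows. It is purely illustrative and does not use torch.distributed:

```python
import torch

torch.manual_seed(0)
# 6 gathered embedding rows, shared by both simulated ranks.
full_emb = torch.randn(6, 4, requires_grad=True)

# Reference: gradient of the full loss computed on one process.
full_loss = (full_emb ** 2).sum()
ref_grad, = torch.autograd.grad(full_loss, full_emb)

# Per-"rank" losses over disjoint sequence shards (rows 0-2 and 3-5).
# Each rank's gradient w.r.t. the gathered tensor is therefore partial:
# nonzero only on the rows its sequence shard attends to.
g0, = torch.autograd.grad((full_emb[:3] ** 2).sum(), full_emb)
g1, = torch.autograd.grad((full_emb[3:] ** 2).sum(), full_emb)

# all_reduce(SUM) step: summing the partial grads recovers the full gradient.
summed = g0 + g1
assert torch.allclose(summed, ref_grad)

# If rank 0 produced embedding rows [0, 2), its local ViT gradient is the slice:
rank0_grad = summed[0:2]
```

This is why the earlier `* dp_size` scaling was wrong: the partial gradients must be summed, not rescaled, because each rank sees a different (not a uniformly scaled) slice of the loss.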
aoshen524 and others added 3 commits March 3, 2026 22:47
…d contiguous guard

- Trim verbose docstrings to concise one-liners
- Delete dead store ctx.hidden_size (written in forward, never read in backward)
- Simplify hidden_size detection: self.config.out_hidden_size
- Add requires_grad_() for empty rank to participate in backward all_reduce
- Add .contiguous() guard before all_reduce (NCCL requirement)
- Reuse get_image_patch_counts in spatial_merge_size==1 path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Detect Qwen3-VL via model attribute (hasattr deepstack_merger_list)
instead of return type, so empty ranks that skip original_forward
still create matching empty deepstack tensors and participate in
all-gather — preventing NCCL deadlock.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add `vision_dp` field to TrainEngineConfig (default False). Vision DP
is now only applied when explicitly enabled, instead of unconditionally
when context_parallel_size > 1.

Config flow: TrainEngineConfig.vision_dp → FSDPEngine → apply_monkey_patch
→ apply_vision_dp_patch().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
aoshen524 and others added 9 commits March 4, 2026 22:38
…-check

The assertion `tensor[a:b].shape[0] == b - a` always passes by Python
slicing semantics. Replace with an independent verification path:
compute expected patches from `get_image_patch_counts(local_grid_thw)`
instead of from the same `offsets` used for slicing.

Addresses review comment: https://github.com/inclusionAI/AReaL/pull/929/files#r2863271094

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
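An illustrative version of the patch-count and contiguous-slice helpers discussed here (the names follow the commit message, but the bodies are a sketch, not the PR's code):

```python
import torch

def get_image_patch_counts(grid_thw):
    # grid_thw: (num_images, 3) int tensor of (t, h, w) patch grid per image.
    return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()

def prepare_local_vision_inputs(pixel_values, grid_thw, start, end):
    # Contiguous image assignment means a rank's patches form one slice:
    # cumsum the per-image counts, then take [offsets[start], offsets[end]).
    counts = get_image_patch_counts(grid_thw)
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)
    return pixel_values[offsets[start]:offsets[end]], grid_thw[start:end]
```

The independent cross-check the commit describes would recompute the expected patch total from `get_image_patch_counts(local_grid_thw)` and compare it against the slice length, rather than deriving both numbers from the same offsets.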
- Rename tests to follow `test_<what>_<condition>_<expected>()` convention
  for self-documenting test output in CI logs
- Add TestCreateDpVisionForward: verify sp_size<=1 calls original_forward
- Add TestPatchVisionClass: verify forward replaced + idempotency guard
- Add TestApplyVisionDpPatch: verify ImportError does not crash
- Add gather_vision_embeddings same-storage passthrough test
- Add embedding_counts empty input test
- Add contiguous coverage test across multiple dp_sizes

Addresses review comments:
- https://github.com/inclusionAI/AReaL/pull/929/files#r2845688349
- https://github.com/inclusionAI/AReaL/pull/929/files#r2845665917

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The "vision_dp" name was misleading — the feature shards vision
encoding across SP (sequence parallel) ranks, not DP ranks.
Rename globally for clarity, and fix several bugs found in review.

Key changes:
- Rename vision_dp -> shard_vision_across_sp across config, code, docs, tests
- Fix NCCL deadlock: add .requires_grad_() on empty deepstack tensors
- Fix perf: pass grid_thw_cpu instead of GPU tensor to avoid 4x cudaStreamSync
- Fix latent issue: add .contiguous() guard before all_gather on no-padding path
- Add warnings for silent no-ops (sp_size<=1, non-VLM, zero classes patched)

Refs: inclusionAI#929
Move vision_sp_shard.py from areal/utils/ to areal/models/fsdp/ to
co-locate with ulysses.py. Also move shard_vision_across_sp config
field from TrainEngineConfig to FSDPEngineConfig where it belongs.

Key changes:
- Move areal/utils/vision_sp_shard.py -> areal/models/fsdp/vision_sp_shard.py
- Move shard_vision_across_sp field to FSDPEngineConfig, access via self.config.fsdp
- Rename internal helpers with _ prefix (not part of public API)
- Add defensive .cpu() guard in _get_image_patch_counts/_get_image_embedding_counts
- Polish module docstring
- Update all imports, mock paths, and docs tables

Refs: inclusionAI#929
Extract _unpack_deepstack helper, simplify spatial_merge_size
lookup, table-driven _VISION_CLASSES registry, remove redundant
.cpu() guards and tautological assertion. Trim private docstrings.

Key changes:
- Extract _unpack_deepstack() to isolate Qwen3-VL deepstack logic
- Simplify spatial_merge_size to getattr(self, "spatial_merge_size", 1)
- Fix hidden_size fallback: out_hidden_size or hidden_size
- Remove redundant .cpu() in _get_image_patch_counts/_get_image_embedding_counts
- Remove tautological cross-check assertion in _prepare_local_vision_inputs
- Refactor apply_vision_sp_shard_patch to table-driven _VISION_CLASSES registry
- Trim multi-line docstrings on private functions to concise one-liners

Refs: inclusionAI#929
…ests

vision_sp_shard patches transformers model classes, so it belongs
in models/transformers/ not models/fsdp/. Also improve test quality.

Key changes:
- Move vision_sp_shard.py from models/fsdp/ to models/transformers/
- Update import paths in ulyssess_patch.py and test file
- Add _unpack_deepstack tests (3 code paths incl. NCCL deadlock fix)
- Extract _assert_all_images_assigned helper to DRY 4 callsites
- Add merge_size=1 equivalence test for refactored code path
- Merge duplicate gather test, simplify apply_patch test
- Remove "Adapted from verl" attribution from test docstring

Refs: inclusionAI#929
@rchardx rchardx changed the title from "[Feature] Vision Data Parallel: distribute ViT across Ulysses SP ranks" to "feat(models): shard vision encoder across Ulysses SP ranks" on Mar 5, 2026
Collaborator

@rchardx rchardx left a comment


LGTM

@rchardx rchardx merged commit 036ab16 into inclusionAI:main Mar 5, 2026
5 checks passed
dingzhiqiang pushed a commit that referenced this pull request Mar 16, 2026
When using Ulysses Sequence Parallelism with VLMs, every SP rank
redundantly runs the full vision encoder. This adds a
`shard_vision_across_sp` option that distributes whole images across
SP ranks, runs ViT locally, and all-gathers the embeddings -
eliminating redundant computation while preserving gradient
correctness via all_reduce(SUM) in backward.

Key changes:
- Add vision_sp_shard.py with greedy contiguous image assignment,
  padded all-gather, and custom autograd backward
- Support Qwen2-VL, Qwen2.5-VL, and Qwen3-VL (incl. deepstack)
- Add shard_vision_across_sp flag to FSDPEngineConfig
  (effective only when context_parallel_size > 1)
- Monkey-patch VisionTransformer.forward via table-driven registry
- Add 31 CPU-only unit tests covering all helper functions,
  deepstack unpacking, patching idempotency, and integration

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Wentai Zhang <zhangwentai.zwt@antgroup.com>