
[Feat]: support VAE patch parallelism#756

Open
dongbo910220 wants to merge 37 commits into vllm-project:main from dongbo910220:patch_vae_parallelism

Conversation

@dongbo910220
Contributor

@dongbo910220 dongbo910220 commented Jan 12, 2026

Purpose

  • Implement VAE patch/tile parallelism (pp) for Z-Image: in multi-GPU scenarios, each rank decodes a subset of tiles, and rank 0 gathers and stitches them back together. This reduces peak VAE decode memory usage on the workers (see the sketch below).
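For illustration only, a minimal sketch of this flow (my own approximation, not the PR's actual code; the round-robin tile assignment and the stitch_tiles helper are assumptions, and the VAE is assumed to be diffusers-style, i.e. decode() returns an object with .sample):

```python
# Illustrative sketch only; tile assignment, gathering, and stitching are assumed,
# not copied from the PR. `stitch_tiles` is a hypothetical blend/crop helper.
import torch.distributed as dist

def distributed_tiled_decode(vae, latent_tiles, group=None):
    """latent_tiles: list of (row, col, latent_tile) covering the full latent grid."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # Round-robin assignment: each rank decodes every world_size-th tile.
    local = [
        (idx, vae.decode(tile).sample)
        for idx, (_, _, tile) in enumerate(latent_tiles)
        if idx % world_size == rank
    ]

    # Rank 0 gathers the decoded tiles from all ranks and stitches them together.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local, gathered, dst=0, group=group)
    if rank != 0:
        return None
    decoded = dict(pair for part in gathered for pair in part)
    return stitch_tiles(decoded, latent_tiles)  # hypothetical helper
```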

Technical Design & Documentation

For a detailed explanation of the VAE patch parallelism method, including when to use _distributed_tiled_decode vs _distributed_patch_decode and their differences, please refer to this Document.

Test Plan

pytest -q tests/diffusion/models/z_image/test_zimage_tp_constraints.py
pytest -q tests/e2e/offline_inference/test_zimage_tensor_parallel.py

Test Result

All tests passed.

Benchmarks (V100 x 2)

Config: 1024x1024, 6 steps, TP=2, tiling enabled.

VAE Decode Statistics

1. Memory Reduction (Peak VRAM Reserved)

By distributing the tiling workload, the peak reserved memory is significantly reduced across both ranks.

| Mode | Peak VRAM Reserved (GPU0) | Peak VRAM Reserved (GPU1) | Reduction (per GPU) |
|---|---|---|---|
| pp=1 (Off) | 27.004 GiB | 27.004 GiB | - |
| pp=2 (On) | 23.316 GiB | 23.316 GiB | -3.688 GiB (~13.7%) |

2. Speed Gain (VAE Decode Latency)

The parallel implementation provides a substantial speedup by leveraging multi-GPU compute for the decoding phase.

| Metric | pp=1 (Serial Avg) | pp=2 (Parallel Avg) | Improvement |
|---|---|---|---|
| VAE Decode Time | ~1356.0 ms | ~991.2 ms | -364.8 ms (~26.9% speedup) |

Qualitative Comparison (VAE Parallelism)

| pp=1 Output (Serial) | pp=2 Output (Parallel) |
|---|---|
| [image: zimage_tp2_pp1_1024_steps6] | [image: zimage_tp2_pp2_1024_steps6] |
| 1024×1024, steps=6, TP=2, tiling | Consistent output across parallelized VAE ranks |


@dongbo910220 dongbo910220 changed the title from "diffusion(z-image): support VAE patch parallelism" to "[Feat]: support VAE patch parallelism" on Jan 12, 2026
@dongbo910220 dongbo910220 marked this pull request as ready for review January 12, 2026 20:33

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f575a6d475


@hsliuustc0106
Collaborator

Is there any speed comparison?

Contributor

@lishunyang12 lishunyang12 left a comment


Can you also test it on text_to_image models, like Wan2.2?

@wtomin
Contributor

wtomin commented Jan 13, 2026

It is indeed an important feature for diffusion models with VAE. Nice job!

I notice that there is an existing interface for VAE patch parallelism, vae_parallel_size:

```python
def initialize_model_parallel(
    data_parallel_size: int = 1,
    cfg_parallel_size: int = 1,
    sequence_parallel_size: int | None = None,
    ulysses_degree: int = 1,
    ring_degree: int = 1,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    vae_parallel_size: int = 0,
    backend: str | None = None,
) -> None:
```

If vae_parallel_size is greater than 0, it will initialize another communication group in addition to the existing dit_group:

```python
if vae_parallel_size > 0:
    init_vae_group(dit_parallel_size, vae_parallel_size, backend)
init_dit_group(dit_parallel_size, backend)
```

In your current implementation, VAE parallelism uses the existing dit_group:

```python
group = get_dit_group()
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
```

Which do you think is better: initializing a separate communication group for VAE patch parallelism only, or sharing the communication group with other parallelism methods (such as SP and TP)?

@ZJY0516
Collaborator

ZJY0516 commented Jan 13, 2026

Is there any speed comparison?

It's mainly for reducing peak memory usage.

@wtomin
Contributor

wtomin commented Jan 13, 2026

Can you also list the memory reduction and speed gain of this PR?

@hsliuustc0106
Collaborator

Do we need to apply it one by one?

@dongbo910220
Contributor Author

Can you also test it on text_to_image models, like Wan2.2?

Currently, this PR focuses on the MVP implementation specifically for Z-Image to establish the baseline and verify the parallelism logic. The VAE patch parallelism logic is implemented within ZImagePipeline for now.

@dongbo910220
Contributor Author

dongbo910220 commented Jan 13, 2026

@hsliuustc0106 and @wtomin Thanks for asking! I have updated the PR description with detailed performance and memory benchmarks.
The latest performance data shows a ~27% speedup in VAE decode latency (dropping from ~1356ms to ~991ms).

@dongbo910220
Contributor Author

dongbo910220 commented Jan 13, 2026

Do we need to apply it one by one?

To clarify:

  1. For this MVP: Yes, it is currently implemented only in ZImagePipeline.
  2. For the future: The core logic (_distributed_tiled_decode) could be extracted into a shared utility (e.g., in vllm_omni.diffusion.utils) in a later phase. This would allow most standard diffusers pipelines to adopt it easily without copy-pasting code 'one by one' (a rough interface sketch follows below).
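As a rough sketch of what such a shared utility could look like (the class name ParallelVAEDecoder and its constructor arguments are hypothetical, not this PR's actual API):

```python
# Hypothetical interface for a reusable VAE patch-parallel decoder; not the PR's code.
import torch.distributed as dist

class ParallelVAEDecoder:
    """Wraps a diffusers AutoencoderKL-like VAE with patch/tile-parallel decode."""

    def __init__(self, vae, group=None, tiled_decode_fn=None):
        self.vae = vae
        self.group = group
        # The actual tiled/patch decode routine (e.g. something like the PR's
        # _distributed_tiled_decode) would be plugged in here.
        self.tiled_decode_fn = tiled_decode_fn

    def decode(self, latents):
        # Single rank (or no distributed init): nothing to shard, plain decode.
        if not dist.is_initialized() or dist.get_world_size(self.group) == 1:
            return self.vae.decode(latents).sample
        # Multi-rank: delegate to the distributed tiled/patch decode path.
        return self.tiled_decode_fn(self.vae, latents, self.group)
```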

Does this align with your expectation?

@dongbo910220
Contributor Author

dongbo910220 commented Jan 13, 2026

Which do you think is better: initializing a separate communication group for VAE patch parallelism only, or sharing the communication group with other parallelism methods (such as SP and TP)?

Great catch regarding the existing vae_parallel_size!

You are right that vae_parallel_size/init_vae_group implies a disaggregated architecture (e.g., using dedicated ranks for VAE separate from DiT ranks).

However, this PR implements in-place patch parallelism, where the same set of DiT workers share the VAE workload. Since we are reusing the same physical ranks, using the existing dit_group is the most logical and simplest approach for this MVP.

Creating a dedicated sub-group (e.g., a vae_patch_group containing only active ranks) could be an optimization for the future (especially if world_size >> pp_size), but reusing dit_group is sufficient and correct for now.
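For reference, a hypothetical sketch of that future optimization using torch.distributed (the helper name init_vae_patch_group and the choice of ranks are assumptions, not part of this PR):

```python
# Hypothetical future optimization, not implemented in this PR: carve a dedicated
# sub-group out of the global ranks so only pp_size ranks join VAE patch decode.
import torch.distributed as dist

def init_vae_patch_group(pp_size: int, backend: str | None = None):
    world_size = dist.get_world_size()
    assert 0 < pp_size <= world_size
    # Use the first pp_size global ranks for VAE patch decoding (arbitrary choice).
    vae_ranks = list(range(pp_size))
    # new_group must be called by every rank, including those not in vae_ranks.
    group = dist.new_group(ranks=vae_ranks, backend=backend)
    return group if dist.get_rank() in vae_ranks else None
```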

@dongbo910220 dongbo910220 force-pushed the patch_vae_parallelism branch 3 times, most recently from 78d1832 to 71d2a63 on January 13, 2026 22:24
Collaborator

@SamitHuang SamitHuang left a comment


It's not yet a general solution. How about setting it to a draft state first?

@dongbo910220 dongbo910220 marked this pull request as draft January 14, 2026 07:01
@dongbo910220 dongbo910220 force-pushed the patch_vae_parallelism branch 3 times, most recently from c822f97 to 71d2a63 on January 14, 2026 08:38
@wtomin
Contributor

wtomin commented Jan 15, 2026

Besides the implementation in this PR, may I also recommend the VAE Patch Parallelism in x-DiT? I think the x-DiT implementation is lossless; we could support both the lossy (current PR) and lossless implementations in vLLM-Omni.

@dongbo910220 dongbo910220 force-pushed the patch_vae_parallelism branch from ea79ec1 to d30cd54 on January 15, 2026 17:38
@dongbo910220
Contributor Author

It's not yet a general solution. How about setting it to a draft state first?

Agreed. The VAE parallelism logic will be decoupled from Z-Image and extracted into a generic utility class.

@dongbo910220
Contributor Author

Besides the implementation in this PR, may I also recommend the VAE Patch Parallelism in x-DiT? I think the x-DiT implementation is lossless; we could support both the lossy (current PR) and lossless implementations in vLLM-Omni.

Thanks for the recommendation! I checked x-DiT's implementation and it's definitely interesting.

Just to clarify on the 'lossless' part: our current implementation is actually numerically lossless (diff = 0) for large images (e.g., 1536x1536) where the standard Overlap+Blend tiling kicks in. For smaller images using the Halo+Crop strategy, the difference is negligible (mean_abs < 0.003).
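To make the Halo+Crop idea concrete, here is a generic sketch of the technique (not the PR's exact code; the halo size and the VAE scale factor of 8 are assumptions):

```python
# Decode a latent tile with a small halo of neighboring latents, then crop the
# halo out of the decoded pixels so tile seams fall on discarded context.
def decode_tile_with_halo(vae, latents, y0, y1, x0, x1, halo=8, scale=8):
    H, W = latents.shape[-2:]
    ty0, ty1 = max(0, y0 - halo), min(H, y1 + halo)
    tx0, tx1 = max(0, x0 - halo), min(W, x1 + halo)
    pixels = vae.decode(latents[..., ty0:ty1, tx0:tx1]).sample
    # Keep only the pixel region corresponding to the original [y0:y1, x0:x1] tile.
    cy0, cx0 = (y0 - ty0) * scale, (x0 - tx0) * scale
    return pixels[..., cy0:cy0 + (y1 - y0) * scale, cx0:cx0 + (x1 - x0) * scale]
```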

Regarding supporting x-DiT's method: I agree it would be great to support both!
  • This PR (current): focuses on a lightweight, non-intrusive approach compatible with any diffusers.AutoencoderKL, without modifying model internals.
  • Future: we can certainly explore adding x-DiT's method (which requires deeper integration, such as custom GroupNorm/Conv2d) as an alternative backend in a follow-up PR.

@dongbo910220 dongbo910220 force-pushed the patch_vae_parallelism branch from 3433376 to 36869f6 on January 28, 2026 04:47
Collaborator

@ZJY0516 ZJY0516 left a comment


I have a small question: if we reuse the distributed group, does that mean VAE parallelism becomes unavailable when other parallelism methods are not in use?

@dongbo910220
Contributor Author

I have a small question: if we reuse the distributed group, does that mean VAE parallelism becomes unavailable when other parallelism methods are not in use?

Good point. Yes. VAE patch parallelism reuses the existing diffusion ProcessGroup (dit_group) and does not spawn extra ranks. If dit_group.world_size == 1 (no TP/SP/CFG/etc.), there is nothing to shard across, so it would not reduce peak decoder memory; we therefore fall back to the original vae.decode to avoid extra overhead and potential numerical differences.

Collaborator

@ZJY0516 ZJY0516 left a comment


Overall, LGTM.

@ZJY0516
Collaborator

ZJY0516 commented Feb 5, 2026

cc @wtomin @SamitHuang @mxuax

@mxuax
Contributor

mxuax commented Feb 6, 2026

LGTM, I think this PR is ready :)
