[Feat]: support VAE patch parallelism #756
dongbo910220 wants to merge 37 commits into vllm-project:main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f575a6d475
Is there any speed comparison?
lishunyang12
left a comment
Can you also test it on text_to_image, e.g. Wan2.2?
It is indeed an important feature for diffusion models with a VAE. Nice job! I notice that there is an existing interface for VAE patch parallelism in vllm-omni/vllm_omni/diffusion/distributed/parallel_state.py (lines 563 to 573 in 7e14c94). If vae_parallel_size is greater than 0, it will initialize another communication group in addition to the existing ones (lines 711 to 713 in 7e14c94). In your current implementation, VAE parallelism uses the existing group:
group = get_dit_group()
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
What do you think is better: initializing another communication group for VAE patch parallelism only, or sharing the communication group with other parallelism methods (such as SP and TP)?
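To make the two options concrete, here is a rough sketch (the function bodies are assumptions for illustration; only `get_dit_group` and the `vae_parallel_size` interface come from the existing parallel_state.py):

```python
import torch.distributed as dist

# Option A: a dedicated communication group for VAE ranks only
# (roughly what init_vae_group implies when vae_parallel_size > 0).
def make_dedicated_vae_group(vae_parallel_size: int):
    vae_ranks = list(range(vae_parallel_size))  # assumed rank assignment
    return dist.new_group(ranks=vae_ranks)

# Option B: reuse the existing DiT group for patch parallelism (this PR).
def get_vae_patch_parallel_ctx():
    group = get_dit_group()  # existing helper in parallel_state.py
    return group, dist.get_world_size(group), dist.get_rank(group)
```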
It's mainly for reducing peak memory usage.
Can you also list the memory reduction and speed gain of this PR?
Do we need to apply it one by one?
Currently, this PR focuses on the MVP implementation specifically for Z-Image to establish the baseline and verify the parallelism logic. The VAE patch parallelism logic is implemented within ZImagePipeline for now.
@hsliuustc0106 and @wtomin Thanks for asking! I have updated the PR description with detailed performance and memory benchmarks.
To clarify:
Does this align with your expectation?
Great catch regarding the existing vae_parallel_size! You are right that vae_parallel_size/init_vae_group implies a disaggregated architecture (e.g., dedicated ranks for the VAE, separate from the DiT ranks). However, this PR implements in-place patch parallelism, where the same set of DiT workers shares the VAE workload. Since we are reusing the same physical ranks, using the existing dit_group is the most logical and simplest approach for this MVP. Creating a dedicated sub-group (e.g., a vae_patch_group containing only the active ranks) could be a future optimization (especially if world_size >> pp_size), but reusing dit_group is sufficient and correct for now.
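For reference, a minimal sketch of what in-place patch parallelism over the shared DiT group looks like (the helper names, the chunking scheme, and the diffusers-style `.sample` access are assumptions, not the PR's exact code):

```python
import torch
import torch.distributed as dist

def distributed_tile_decode(vae, tiles, stitch_fn):
    """Decode a list of latent tiles across the existing DiT group."""
    group = get_dit_group()                  # reuse the DiT process group
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # Each rank decodes a contiguous chunk of the tile list.
    per_rank = (len(tiles) + world_size - 1) // world_size
    with torch.no_grad():
        local = [vae.decode(t).sample
                 for t in tiles[rank * per_rank:(rank + 1) * per_rank]]

    # Rank 0 gathers every rank's decoded tiles and stitches the final image.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local, gathered, dst=0, group=group)
    if rank == 0:
        ordered = [t for rank_tiles in gathered for t in rank_tiles]
        return stitch_fn(ordered)
    return None
```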
SamitHuang
left a comment
It's not yet a general solution. How about setting it to draft state first?
Besides the implementation in this PR, may I also recommend the VAE patch parallelism in x-DiT? I think the x-DiT implementation is lossless, so we could support both the lossy (current PR) and lossless implementations in vLLM-Omni.
Decouple the VAE parallelism logic from Z-Image and extract it into a generic utility class.
Thanks for the recommendation! I checked x-DiT's implementation and it's definitely interesting. Just to clarify on the 'lossless' part: our current implementation is actually numerically lossless (diff = 0) for large images (e.g., 1536x1536) where the standard Overlap+Blend tiling kicks in. For smaller images using the Halo+Crop strategy, the difference is negligible (mean_abs < 0.003). Regarding supporting x-DiT's method: I agree it would be great to support both!
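For context, here is roughly what the Halo+Crop path does for a single rank's vertical slice (the signature, the `scale_factor`, and the diffusers-style `.sample` access are illustrative assumptions, not the PR's exact code):

```python
import torch

def decode_slice_with_halo(vae, latents, start, end, halo=2, scale_factor=8):
    """Decode latent rows [start, end) with `halo` extra rows of context on
    each side, then crop the decoded pixels back to the rows this rank owns."""
    height = latents.shape[-2]
    lo, hi = max(start - halo, 0), min(end + halo, height)
    with torch.no_grad():
        decoded = vae.decode(latents[..., lo:hi, :]).sample
    # Crop away the decoded halo so neighbouring slices tile without seams;
    # any residual difference vs. full-image decode comes from the finite halo.
    top = (start - lo) * scale_factor
    return decoded[..., top:top + (end - start) * scale_factor, :]
```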
ZJY0516
left a comment
I have a small question. If we reuse the distributed group, does that mean VAE parallelism becomes unavailable when other parallelism methods are not in use?
Good point. Yes. VAE patch parallelism reuses the existing diffusion ProcessGroup (dit_group) and does not spawn extra ranks. If dit_group.world_size == 1 (no TP/SP/CFG/etc.), there is nothing to shard across, so it would not reduce peak decoder memory; we therefore fall back to the original vae.decode to avoid extra overhead and potential numerical differences.
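A tiny sketch of that fallback (the wrapper name and call shapes are assumptions for illustration; the actual code in the PR may differ):

```python
import torch.distributed as dist

def maybe_parallel_decode(vae, latents, distributed_decode):
    """Use patch parallelism only when the DiT group actually has peers."""
    group = get_dit_group()  # existing helper in parallel_state.py
    if dist.get_world_size(group) == 1:
        # Single rank: nothing to shard, so use the stock decoder to avoid
        # extra collectives and any numerical differences.
        return vae.decode(latents).sample
    return distributed_decode(latents)
```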
LGTM, I think this PR is ready :)
Purpose
Z-Image: In multi-GPU scenarios, each rank decodes a subset of tiles. Rank 0 gathers and stitches them together. This reduces the peak VAE decode memory usage on workers.
Technical Design & Documentation
For a detailed explanation of the VAE patch parallelism method, including when to use _distributed_tiled_decode vs _distributed_patch_decode and their differences, please refer to this Document.
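At a glance, the dispatch between the two paths could look roughly like this (the size threshold and attribute names are assumptions; see the linked document for the actual criteria):

```python
def distributed_decode(self, latents):
    # Large images use overlapping tiles that are blended after decoding
    # (matching the standard tiled decode); smaller images use halo+crop patches.
    pixel_size = latents.shape[-1] * self.vae_scale_factor   # assumed attribute
    if pixel_size >= self.tile_sample_min_size:              # assumed threshold
        return self._distributed_tiled_decode(latents)
    return self._distributed_patch_decode(latents)
```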
Test Plan
pytest -q tests/diffusion/models/z_image/test_zimage_tp_constraints.py
pytest -q tests/e2e/offline_inference/test_zimage_tensor_parallel.py
Test Result
All tests passed.
Benchmarks (V100 x 2)
Config: 1024x1024, 6 steps, TP=2, tiling enabled.
VAE Decode Statistics
1. Memory Reduction (Peak VRAM Reserved)
By distributing the tiling workload, the peak reserved memory is significantly reduced across both ranks.
2. Speed Gain (VAE Decode Latency)
The parallel implementation provides a substantial speedup by leveraging multi-GPU compute for the decoding phase.
Qualitative Comparison (VAE Parallelism)