Changes from 28 commits

Commits (37)
4c7f60f
diffusion(z-image): add VAE patch parallel decode
dongbo910220 Jan 13, 2026
5bddc7c
tests/examples: cover VAE patch parallelism
dongbo910220 Jan 13, 2026
f72343f
examples: keep VAE slicing opt-in
dongbo910220 Jan 13, 2026
4aa8819
examples: keep VAE tiling opt-in
dongbo910220 Jan 13, 2026
01dc18c
diffusion(z-image): prototype DistVAE-style patch decode
dongbo910220 Jan 13, 2026
dfccf30
diffusion(z-image): extend patch decode and balance tiles
dongbo910220 Jan 13, 2026
4997ab9
diffusion(z-image): clarify patch decode comment
dongbo910220 Jan 13, 2026
3bb5cca
diffusion: refactor VAE patch parallelism (ref_dit)
dongbo910220 Jan 14, 2026
75c8618
diffusion: dedupe VAE patch-parallel helpers
dongbo910220 Jan 15, 2026
a8b0e96
tests: add unit coverage for VAE patch parallelism helpers
dongbo910220 Jan 15, 2026
6dd3e3a
diffusion: inject VAE decode profiling wrapper
dongbo910220 Jan 15, 2026
02dabb0
diffusion: remove VAE decode profiling hooks
dongbo910220 Jan 15, 2026
d730c04
diffusion: allowlist VAE patch parallel install
dongbo910220 Jan 15, 2026
b63544e
diffusion: drop env override for VAE patch parallel size
dongbo910220 Jan 16, 2026
432276f
diffusion: remove legacy vae_parallel_size group
dongbo910220 Jan 16, 2026
5ee4f97
diffusion: document and reorganize VAE patch parallelism
dongbo910220 Jan 20, 2026
147a690
docs: add tensor parallelism quickstart
dongbo910220 Jan 20, 2026
a7f5feb
diffusion: auto-enable VAE tiling for vae pp
dongbo910220 Jan 20, 2026
de1c634
docs: expand diffusion acceleration support table
dongbo910220 Jan 20, 2026
31bfb43
docs: add VAE patch parallel example and align tables
dongbo910220 Jan 20, 2026
1536fff
tests: relocate VAE patch parallel unit tests
dongbo910220 Jan 20, 2026
9ae119f
tests: align Z-Image TP e2e size with upstream
dongbo910220 Jan 20, 2026
20e1d3a
cleanup: move VAE pp unit test and remove empty vae pkg
dongbo910220 Jan 21, 2026
36869f6
style: apply ruff format
dongbo910220 Jan 28, 2026
61afcb8
Merge origin/main into patch_vae_parallelism
dongbo910220 Jan 28, 2026
69eeb79
tests: fix zimage platform checks
dongbo910220 Jan 28, 2026
d187eb9
Merge origin/main into patch_vae_parallelism
dongbo910220 Feb 4, 2026
802da55
fix: vllm 0.14 multimodal import
dongbo910220 Feb 4, 2026
2ee0c12
refactor: clarify VAE patch parallel naming
dongbo910220 Feb 4, 2026
d923536
test: always enforce eager in zimage e2e
dongbo910220 Feb 4, 2026
c95c0ff
Merge remote-tracking branch 'origin/main' into patch_vae_parallelism
dongbo910220 Feb 4, 2026
a8f1024
refactor: simplify vae patch parallel size access
dongbo910220 Feb 5, 2026
7f2e3ef
test: rename zimage parallel e2e file
dongbo910220 Feb 5, 2026
b954d2d
test: clarify zimage parallelism e2e coverage
dongbo910220 Feb 5, 2026
cdba974
test: enable compile for zimage e2e
dongbo910220 Feb 6, 2026
d7e55fb
test: trim zimage e2e docstring
dongbo910220 Feb 6, 2026
5541d13
test: drop eager guard in zimage e2e
dongbo910220 Feb 6, 2026
69 changes: 55 additions & 14 deletions docs/user_guide/diffusion/parallelism_acceleration.md
@@ -14,23 +14,25 @@ The following parallelism methods are currently supported in vLLM-Omni:

4. [Tensor Parallelism](#tensor-parallelism): Tensor parallelism shards model weights across devices. This can reduce per-GPU memory usage. Note that for diffusion models we currently shard the majority of layers within the DiT.

5. [VAE Patch Parallelism](#vae-patch-parallelism): VAE patch parallelism shards VAE decode/encode spatially across ranks. This can reduce the peak memory of VAE decode and, depending on resolution and communication overhead, also speed it up.

The following table shows which models are currently supported by parallelism method:

### ImageGen

| Model | Model Identifier | Ulysses-SP | Ring-SP | CFG-Parallel | Tensor-Parallel |
|--------------------------|--------------------------------------|:----------:|:-------:|:------------:|:---------------:|
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ❌ | ✅ |
| Model | Model Identifier | Ulysses-SP | Ring-Attention | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel |
|--------------------------|--------------------------------------|:----------:|:--------------:|:------------:|:---------------:|:------------------:|
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ✅ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ✅ | ❌ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) | ✅ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | ❌ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ❌ | ✅ | ❌ |

!!! note "TP Limitations for Diffusion Models"
We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.
@@ -47,7 +49,7 @@ The following table shows which models are currently supported by parallelism method:

### VideoGen

| Model | Model Identifier | Ulysses-SP | Ring-SP | Tensor-Parallel |
| Model | Model Identifier | Ulysses-SP | Ring-Attention | Tensor-Parallel |
|-------|------------------|------------|---------|--------------------------|
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ❌ |

@@ -76,6 +78,45 @@ outputs = omni.generate(
)
```

### VAE Patch Parallelism

VAE patch parallelism distributes the VAE decode/encode workload across multiple ranks by splitting the latent spatially. It is configured via `DiffusionParallelConfig.vae_patch_parallel_size` and can be combined with other parallelism methods (e.g., TP).

!!! note "Enablement and feature gate"
    - VAE patch parallelism is **enabled only for validated pipelines** (currently `Tongyi-MAI/Z-Image-Turbo`).
    - If `vae_patch_parallel_size > 1` is set for a validated pipeline, vLLM-Omni automatically enables `vae_use_tiling` as a safety gate. (We key on `vae_use_tiling` because it indicates the VAE supports diffusers tiling parameters such as `tile_latent_min_size` and `tile_overlap_factor`.)

#### Offline Inference

```python
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig

omni = Omni(
model="Tongyi-MAI/Z-Image-Turbo",
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
vae_patch_parallel_size=2,
),
vae_use_tiling=True,
)

outputs = omni.generate(
prompt="a cat reading a book",
num_inference_steps=9,
width=1024,
height=1024,
)
```

#### How it works (method selection)

VAE patch parallelism automatically selects between two internal decode methods based on whether diffusers tiling would kick in:

- `_distributed_tiled_decode`: Used when the latent spatial size exceeds `vae.tile_latent_min_size`, i.e., when diffusers itself would use tiled decode. Each rank decodes a subset of tiles; rank 0 gathers them and runs the same overlap+blend+stitch logic as diffusers, so the output matches the single-rank diffusers tiled decode.

- `_distributed_patch_decode`: Used when diffusers tiling would not kick in. Each rank decodes a grid patch expanded with a latent-space halo; rank 0 then gathers the cropped core patches and stitches them into the full image. This path does no blending and can introduce small numerical differences compared to the non-parallel decode; see the sketch below.
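
The following is a minimal sketch of that selection, assuming the latent tensor is laid out as `(batch, channels, height, width)`; `select_decode_method` is a hypothetical helper name used for illustration, not part of the vLLM-Omni API:

```python
# Minimal sketch of the decode-method selection (hypothetical helper name;
# the real logic lives in vLLM-Omni's VAE patch-parallel module).
def select_decode_method(vae, latents) -> str:
    _, _, height, width = latents.shape  # latent-space spatial size
    tile_min = vae.tile_latent_min_size  # diffusers tiling threshold
    if height > tile_min or width > tile_min:
        # Diffusers tiling would kick in: each rank decodes a subset of
        # tiles and rank 0 reuses the diffusers overlap+blend+stitch.
        return "_distributed_tiled_decode"
    # Otherwise each rank decodes a halo-expanded grid patch; rank 0 crops
    # the halos and stitches the core patches (no blending on this path).
    return "_distributed_patch_decode"
```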

### Sequence Parallelism

#### Ulysses-SP
73 changes: 60 additions & 13 deletions docs/user_guide/diffusion_acceleration.md
@@ -22,6 +22,10 @@ vLLM-Omni also supports parallelism methods for diffusion models, including:

3. [CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step.

4. [Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism) - shards DiT weights across devices to reduce per-GPU memory usage.

5. [VAE Patch Parallelism](diffusion/parallelism_acceleration.md#vae-patch-parallelism) - shards VAE decode/encode spatially across ranks to reduce VAE peak memory (and can speed up VAE decode).

## Quick Comparison

### Cache Methods
@@ -37,19 +41,19 @@ The following table shows which models are currently supported by each acceleration method:

### ImageGen

| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel |
|-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:|
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ |
| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ |
| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | Tensor-Parallel | VAE-Patch-Parallel |
|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:------------:|:---------------:|:------------------:|
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ (TP=2 only) | ✅ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |

### VideoGen

@@ -200,6 +204,47 @@ outputs = omni.generate(
)
```

### Using Tensor Parallelism

```python
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig

omni = Omni(
model="Tongyi-MAI/Z-Image-Turbo",
parallel_config=DiffusionParallelConfig(tensor_parallel_size=2),
)

outputs = omni.generate(
prompt="a cat reading a book",
num_inference_steps=9,
width=512,
height=512,
)
```

### Using VAE Patch Parallelism

```python
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig

omni = Omni(
model="Tongyi-MAI/Z-Image-Turbo",
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
vae_patch_parallel_size=2,
),
)

outputs = omni.generate(
prompt="a cat reading a book",
num_inference_steps=9,
width=1024,
height=1024,
)
```

### Using CFG-Parallel

Run image-to-image:
@@ -232,5 +277,7 @@ For detailed information on each acceleration method:

- **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices
- **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters
- **[Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism)** - Guidance on enabling TP for diffusion models
- **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on configuring sequence parallelism (Ulysses-SP and Ring-Attention)
- **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on running the positive/negative CFG branches across ranks
- **[VAE Patch Parallelism](diffusion/parallelism_acceleration.md#vae-patch-parallelism)** - Guidance on reducing VAE peak memory via patch/tile parallelism
23 changes: 16 additions & 7 deletions examples/offline_inference/text_to_image/text_to_image.py
@@ -118,12 +118,6 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of ready layers (blocks) to keep on GPU during generation.",
)
parser.add_argument(
"--tensor_parallel_size",
type=int,
default=1,
help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
)
parser.add_argument(
"--vae_use_slicing",
action="store_true",
@@ -134,6 +128,18 @@
action="store_true",
help="Enable VAE tiling for memory optimization.",
)
parser.add_argument(
"--tensor_parallel_size",
type=int,
default=1,
help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
)
parser.add_argument(
"--vae_patch_parallel_size",
type=int,
default=1,
help="Number of ranks used for VAE patch/tile parallelism (decode/encode).",
)
return parser.parse_args()


@@ -176,6 +182,7 @@ def main():
ring_degree=args.ring_degree,
cfg_parallel_size=args.cfg_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
vae_patch_parallel_size=args.vae_patch_parallel_size,
)

# Check if profiling is requested via environment variable
@@ -207,8 +214,10 @@ def main():
print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}")
print(
f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, "
f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, "
f"vae_patch_parallel_size={args.vae_patch_parallel_size}"
)
print(f" CPU offload: {args.enable_cpu_offload}")
print(f" Image size: {args.width}x{args.height}")
print(f"{'=' * 60}\n")

106 changes: 106 additions & 0 deletions tests/diffusion/distributed/test_vae_patch_parallel.py
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for VAE patch/tile parallelism helpers (CPU-only)."""

import pytest

from vllm_omni.diffusion.distributed import vae_patch_parallel as pp


class _DummyConfig:
def __init__(self, **attrs):
for k, v in attrs.items():
setattr(self, k, v)


class _DummyVae:
def __init__(self, *, config=None, **attrs):
self.config = config
for k, v in attrs.items():
setattr(self, k, v)


def test_get_vae_spatial_scale_factor_uses_block_out_channels_len_minus_1():
    # Scale factor is 2 ** (len(block_out_channels) - 1): four entries -> 2**3 == 8.
    vae = _DummyVae(config=_DummyConfig(block_out_channels=[128, 256, 512, 512]))
    assert pp._get_vae_spatial_scale_factor(vae) == 8

    # Five entries -> 2**4 == 16.
    vae = _DummyVae(config=_DummyConfig(block_out_channels=[1, 2, 3, 4, 5]))
    assert pp._get_vae_spatial_scale_factor(vae) == 16


def test_get_vae_spatial_scale_factor_defaults_to_8_on_missing_or_empty():
assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig())) == 8
assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig(block_out_channels=[]))) == 8
assert pp._get_vae_spatial_scale_factor(_DummyVae(config=None)) == 8


def test_get_vae_spatial_scale_factor_defaults_to_8_on_exception():
class _BrokenConfig:
@property
def block_out_channels(self):
raise RuntimeError("boom")

assert pp._get_vae_spatial_scale_factor(_DummyVae(config=_BrokenConfig())) == 8


@pytest.mark.parametrize(
("pp_size", "expected"),
[
(0, (1, 1)),
(1, (1, 1)),
(2, (1, 2)),
(3, (1, 3)),
(4, (2, 2)),
(6, (2, 3)),
(8, (2, 4)),
(12, (3, 4)),
(16, (4, 4)),
],
)
def test_factor_pp_grid(pp_size: int, expected: tuple[int, int]):
assert pp._factor_pp_grid(pp_size) == expected
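

# Illustrative cross-check (added example): the expectations above are
# consistent with a "most square grid, rows <= cols" factorization. The
# reference below is a hypothetical sketch of that contract, not the
# actual pp._factor_pp_grid source.
def test_factor_pp_grid_matches_square_factorization_reference():
    import math  # local import keeps this illustrative check self-contained

    def reference(pp_size: int) -> tuple[int, int]:
        if pp_size <= 1:
            return (1, 1)
        # Largest divisor of pp_size that is <= sqrt(pp_size) gives rows.
        rows = max(r for r in range(1, math.isqrt(pp_size) + 1) if pp_size % r == 0)
        return (rows, pp_size // rows)

    for pp_size in (0, 1, 2, 3, 4, 6, 8, 12, 16):
        assert pp._factor_pp_grid(pp_size) == reference(pp_size)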


def test_get_world_rank_pp_size(monkeypatch):
    monkeypatch.setattr(pp.dist, "get_world_size", lambda _: 8)
    monkeypatch.setattr(pp.dist, "get_rank", lambda _: 3)

    world_size, rank, pp_size = pp._get_world_rank_pp_size(object(), 4)
    assert (world_size, rank, pp_size) == (8, 3, 4)

    # A requested patch-parallel size larger than the world size is capped to it.
    world_size, rank, pp_size = pp._get_world_rank_pp_size(object(), 16)
    assert (world_size, rank, pp_size) == (8, 3, 8)


def test_get_vae_out_channels_defaults_to_3():
assert pp._get_vae_out_channels(_DummyVae(config=None)) == 3
assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig())) == 3


def test_get_vae_out_channels_reads_config():
assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels=4))) == 4
assert pp._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels="5"))) == 5


def test_get_vae_tile_params_returns_none_if_missing():
assert pp._get_vae_tile_params(_DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25)) is None
assert pp._get_vae_tile_params(_DummyVae(tile_latent_min_size=128, tile_overlap_factor=None)) is None


def test_get_vae_tile_params_parses_types():
vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25")
assert pp._get_vae_tile_params(vae) == (128, 0.25)


def test_get_vae_tiling_params_returns_none_if_missing():
vae = _DummyVae(tile_latent_min_size=128, tile_overlap_factor=0.25, tile_sample_min_size=None)
assert pp._get_vae_tiling_params(vae) is None

vae = _DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25, tile_sample_min_size=1024)
assert pp._get_vae_tiling_params(vae) is None


def test_get_vae_tiling_params_parses_types():
vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024")
assert pp._get_vae_tiling_params(vae) == (128, 0.25, 1024)