[Feat]: support VAE patch parallelism #756
Status: Open

dongbo910220 wants to merge 37 commits into vllm-project:main from dongbo910220:patch_vae_parallelism
Diff: +860 −79
Commits (37):
- 4c7f60f  diffusion(z-image): add VAE patch parallel decode
- 5bddc7c  tests/examples: cover VAE patch parallelism
- f72343f  examples: keep VAE slicing opt-in
- 4aa8819  examples: keep VAE tiling opt-in
- 01dc18c  diffusion(z-image): prototype DistVAE-style patch decode
- dfccf30  diffusion(z-image): extend patch decode and balance tiles
- 4997ab9  diffusion(z-image): clarify patch decode comment
- 3bb5cca  diffusion: refactor VAE patch parallelism (ref_dit)
- 75c8618  diffusion: dedupe VAE patch-parallel helpers
- a8b0e96  tests: add unit coverage for VAE patch parallelism helpers
- 6dd3e3a  diffusion: inject VAE decode profiling wrapper
- 02dabb0  diffusion: remove VAE decode profiling hooks
- d730c04  diffusion: allowlist VAE patch parallel install
- b63544e  diffusion: drop env override for VAE patch parallel size
- 432276f  diffusion: remove legacy vae_parallel_size group
- 5ee4f97  diffusion: document and reorganize VAE patch parallelism
- 147a690  docs: add tensor parallelism quickstart
- a7f5feb  diffusion: auto-enable VAE tiling for vae pp
- de1c634  docs: expand diffusion acceleration support table
- 31bfb43  docs: add VAE patch parallel example and align tables
- 1536fff  tests: relocate VAE patch parallel unit tests
- 9ae119f  tests: align Z-Image TP e2e size with upstream
- 20e1d3a  cleanup: move VAE pp unit test and remove empty vae pkg
- 36869f6  style: apply ruff format
- 61afcb8  Merge origin/main into patch_vae_parallelism
- 69eeb79  tests: fix zimage platform checks
- d187eb9  Merge origin/main into patch_vae_parallelism
- 802da55  fix: vllm 0.14 multimodal import
- 2ee0c12  refactor: clarify VAE patch parallel naming
- d923536  test: always enforce eager in zimage e2e
- c95c0ff  Merge remote-tracking branch 'origin/main' into patch_vae_parallelism
- a8f1024  refactor: simplify vae patch parallel size access
- 7f2e3ef  test: rename zimage parallel e2e file
- b954d2d  test: clarify zimage parallelism e2e coverage
- cdba974  test: enable compile for zimage e2e
- d7e55fb  test: trim zimage e2e docstring
- 5541d13  test: drop eager guard in zimage e2e

(All commits by dongbo910220. A conceptual sketch of the patch-parallel decode these commits add follows below.)
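For orientation, here is a minimal single-process sketch of the patch-parallel decode idea: split the latent into a near-square grid of tiles, decode each tile independently (one tile per rank in the real multi-GPU version), and stitch the decoded patches back together. Everything below is illustrative, not the PR's actual API: the function names are invented, the decoder is a stub, and tile overlap/blending (which the auto-enabled VAE tiling in commit a7f5feb presumably handles) is omitted.

```python
# Illustrative sketch only -- the PR distributes tiles across GPUs with
# torch.distributed; this CPU version simulates each rank with a loop and
# stands in for vae.decode with a simple upsampler.
import torch
import torch.nn.functional as F


def factor_grid(pp_size: int) -> tuple[int, int]:
    # Most-square (rows, cols) factorization with rows * cols == pp_size,
    # e.g. 4 -> (2, 2), 8 -> (2, 4), 12 -> (3, 4).
    if pp_size <= 1:
        return (1, 1)
    rows = int(pp_size**0.5)
    while pp_size % rows:
        rows -= 1
    return (rows, pp_size // rows)


def patch_parallel_decode(latents: torch.Tensor, pp_size: int, scale: int = 8) -> torch.Tensor:
    """Decode a latent tensor tile by tile and stitch the image back together."""
    rows, cols = factor_grid(pp_size)
    b, _, h, w = latents.shape
    out = torch.empty(b, 3, h * scale, w * scale)
    hs = [round(i * h / rows) for i in range(rows + 1)]  # tile row edges
    ws = [round(j * w / cols) for j in range(cols + 1)]  # tile column edges
    for i in range(rows):      # in the real implementation, each (i, j)
        for j in range(cols):  # tile is decoded by exactly one rank
            tile = latents[:, :, hs[i]:hs[i + 1], ws[j]:ws[j + 1]]
            # Stand-in for vae.decode(tile): collapse channels to RGB and upsample.
            decoded = F.interpolate(tile.mean(1, keepdim=True).expand(-1, 3, -1, -1),
                                    scale_factor=scale)
            out[:, :, hs[i] * scale:hs[i + 1] * scale,
                ws[j] * scale:ws[j + 1] * scale] = decoded
    return out


latents = torch.randn(1, 16, 32, 48)
print(patch_parallel_decode(latents, pp_size=4).shape)  # torch.Size([1, 3, 256, 384])
```

In the actual multi-GPU path the decoded tiles would be exchanged with collective ops (e.g. an all-gather) rather than written into a shared output tensor.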
Changed file shown below: a new 110-line unit-test module for the VAE patch/tile parallelism helpers.

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for VAE patch/tile parallelism helpers (CPU-only)."""

import pytest

from vllm_omni.diffusion.distributed import vae_patch_parallel as vae_patch_parallel


class _DummyConfig:
    def __init__(self, **attrs):
        for k, v in attrs.items():
            setattr(self, k, v)


class _DummyVae:
    def __init__(self, *, config=None, **attrs):
        self.config = config
        for k, v in attrs.items():
            setattr(self, k, v)


def test_get_vae_spatial_scale_factor_uses_block_out_channels_len_minus_1():
    vae = _DummyVae(config=_DummyConfig(block_out_channels=[128, 256, 512, 512]))
    assert vae_patch_parallel._get_vae_spatial_scale_factor(vae) == 8

    vae = _DummyVae(config=_DummyConfig(block_out_channels=[1, 2, 3, 4, 5]))
    assert vae_patch_parallel._get_vae_spatial_scale_factor(vae) == 16


def test_get_vae_spatial_scale_factor_defaults_to_8_on_missing_or_empty():
    assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig())) == 8
    assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_DummyConfig(block_out_channels=[]))) == 8
    assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=None)) == 8


def test_get_vae_spatial_scale_factor_defaults_to_8_on_exception():
    class _BrokenConfig:
        @property
        def block_out_channels(self):
            raise RuntimeError("boom")

    assert vae_patch_parallel._get_vae_spatial_scale_factor(_DummyVae(config=_BrokenConfig())) == 8


@pytest.mark.parametrize(
    ("pp_size", "expected"),
    [
        (0, (1, 1)),
        (1, (1, 1)),
        (2, (1, 2)),
        (3, (1, 3)),
        (4, (2, 2)),
        (6, (2, 3)),
        (8, (2, 4)),
        (12, (3, 4)),
        (16, (4, 4)),
    ],
)
def test_factor_pp_grid(pp_size: int, expected: tuple[int, int]):
    assert vae_patch_parallel._factor_pp_grid(pp_size) == expected


def test_get_world_rank_pp_size(monkeypatch):
    monkeypatch.setattr(vae_patch_parallel.dist, "get_world_size", lambda _: 8)
    monkeypatch.setattr(vae_patch_parallel.dist, "get_rank", lambda _: 3)

    world_size, rank, pp_size = vae_patch_parallel._get_world_rank_pp_size(object(), 4)
    assert (world_size, rank, pp_size) == (8, 3, 4)

    world_size, rank, pp_size = vae_patch_parallel._get_world_rank_pp_size(object(), 16)
    assert (world_size, rank, pp_size) == (8, 3, 8)


def test_get_vae_out_channels_defaults_to_3():
    assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=None)) == 3
    assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig())) == 3


def test_get_vae_out_channels_reads_config():
    assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels=4))) == 4
    assert vae_patch_parallel._get_vae_out_channels(_DummyVae(config=_DummyConfig(out_channels="5"))) == 5


def test_get_vae_tile_params_returns_none_if_missing():
    assert (
        vae_patch_parallel._get_vae_tile_params(_DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25)) is None
    )
    assert (
        vae_patch_parallel._get_vae_tile_params(_DummyVae(tile_latent_min_size=128, tile_overlap_factor=None)) is None
    )


def test_get_vae_tile_params_parses_types():
    vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25")
    assert vae_patch_parallel._get_vae_tile_params(vae) == (128, 0.25)


def test_get_vae_tiling_params_returns_none_if_missing():
    vae = _DummyVae(tile_latent_min_size=128, tile_overlap_factor=0.25, tile_sample_min_size=None)
    assert vae_patch_parallel._get_vae_tiling_params(vae) is None

    vae = _DummyVae(tile_latent_min_size=None, tile_overlap_factor=0.25, tile_sample_min_size=1024)
    assert vae_patch_parallel._get_vae_tiling_params(vae) is None


def test_get_vae_tiling_params_parses_types():
    vae = _DummyVae(tile_latent_min_size="128", tile_overlap_factor="0.25", tile_sample_min_size="1024")
    assert vae_patch_parallel._get_vae_tiling_params(vae) == (128, 0.25, 1024)
```
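Read against these assertions, here is one plausible shape for two of the helpers under test. This is a sketch inferred purely from the expected values above (clamp the requested patch-parallel size to the world size; derive the spatial scale factor as 2 ** (len(block_out_channels) - 1) with a defensive default of 8), not the PR's actual implementation.

```python
# Hypothetical reconstruction from the test expectations; the real
# vae_patch_parallel module may differ in details.
import torch.distributed as dist


def _get_world_rank_pp_size(group, requested_pp_size):
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    # A requested size of 16 on an 8-rank group clamps to 8 in the test above.
    return world_size, rank, min(requested_pp_size, world_size)


def _get_vae_spatial_scale_factor(vae):
    # Each down block after the first halves resolution, so the decoder
    # upsamples by 2 ** (len(block_out_channels) - 1); default to 8 when the
    # config is missing or empty, or when reading it raises.
    try:
        channels = getattr(getattr(vae, "config", None), "block_out_channels", None)
        if channels:
            return 2 ** (len(channels) - 1)
    except Exception:
        return 8
    return 8
```

The string-coercion cases (out_channels="5", tile_latent_min_size="128") likewise imply the remaining helpers run int()/float() over config values before returning them.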