Skip to content

Commit 2578637

Browse files
authored
[None][refactor] parallel vae refactor (#12123)
1 parent be20657 commit 2578637

File tree

8 files changed

+246
-199
lines changed

8 files changed

+246
-199
lines changed
(filename missing from capture — presumably tensorrt_llm/_torch/visual_gen/models/wan/__init__.py, based on the imports below; verify against the commit)

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .parallel_vae import WanParallelVAEAdapter
1+
from .parallel_vae import ParallelVAE_Wan
22
from .pipeline_wan import WanPipeline
33
from .pipeline_wan_i2v import WanImageToVideoPipeline
44
from .transformer_wan import WanTransformer3DModel
@@ -7,5 +7,5 @@
77
"WanPipeline",
88
"WanImageToVideoPipeline",
99
"WanTransformer3DModel",
10-
"WanParallelVAEAdapter",
10+
"ParallelVAE_Wan",
1111
]

tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from typing import Literal
22

3+
import torch
34
import torch.nn as nn
45
from diffusers.models.autoencoders.autoencoder_kl_wan import WanAttentionBlock, WanCausalConv3d
56

67
from tensorrt_llm._torch.visual_gen.modules.vae import (
7-
BaseParallelVAEAdapter,
88
HaloExchangeConv,
99
HaloExchangeConv2dStride2,
1010
ParallelVaeAttentionBlock,
1111
)
12+
from tensorrt_llm._torch.visual_gen.modules.vae.parallel_vae_interface import (
13+
ParallelVAEBase,
14+
SplitSpec,
15+
)
1216
from tensorrt_llm._torch.visual_gen.utils import as_tuple
1317

1418

@@ -26,29 +30,49 @@ def forward(self, x, cache_x=None, *args, **kwargs):
2630
return self._strip_halo(result)
2731

2832

29-
class WanParallelVAEAdapter(BaseParallelVAEAdapter):
30-
"""Parallel VAE adapter for ``AutoencoderKLWan``."""
33+
class ParallelVAE_Wan(ParallelVAEBase):
34+
"""Parallel VAE wrapper for ``AutoencoderKLWan``."""
3135

32-
def _get_chunk_dims(self, split_dim: Literal["height", "width"]) -> dict:
36+
@staticmethod
37+
def make_spec(split_dim: Literal["height", "width"]) -> SplitSpec:
3338
# WAN tensor shapes:
34-
# 5D latent/video : (B, C, T, H, W) H=dim3, W=dim4
35-
# 4D per-frame : (B*T, C, H, W) H=dim2, W=dim3
36-
# 5D attention in : (B, C, T, H, W) H=dim3, W=dim4
39+
# 5D latent/video : (B, C, T, H, W) -> H=dim3, W=dim4
40+
# 4D per-frame : (B*T, C, H, W) -> H=dim2, W=dim3
41+
# 5D attention in : (B, C, T, H, W) -> H=dim3, W=dim4
3742
if split_dim == "height":
38-
return {"input": 3, "conv3d": 3, "conv2d": 2, "attn": 3}
39-
elif split_dim == "width":
40-
return {"input": 4, "conv3d": 4, "conv2d": 3, "attn": 4}
43+
return SplitSpec(split_dim, input_dim=3, conv3d_dim=3, conv2d_dim=2, attn_dim=3)
44+
if split_dim == "width":
45+
return SplitSpec(split_dim, input_dim=4, conv3d_dim=4, conv2d_dim=3, attn_dim=4)
4146
raise ValueError(f"Invalid split_dim: {split_dim}")
4247

43-
def _parallelize_decoder(self) -> None:
44-
self._replace_conv3d(self.vae.decoder)
45-
self._replace_attention(self.vae.decoder)
46-
self._replace_resample_conv2d(self.vae.decoder)
47-
48-
def _parallelize_encoder(self) -> None:
49-
self._replace_conv3d(self.vae.encoder)
50-
self._replace_attention(self.vae.encoder)
51-
self._replace_resample_conv2d_stride2(self.vae.encoder)
48+
# ------------------------------------------------------------------
49+
# encode / decode
50+
# ------------------------------------------------------------------
51+
52+
def _encode_impl(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
53+
x_local, _ = self._split_tensor(x)
54+
z_local = self.vae_backend.encode(x_local, **kwargs)
55+
if isinstance(z_local, (tuple, list)):
56+
z_local = z_local[0]
57+
return self._gather_tensor(z_local)
58+
59+
def _decode_impl(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
60+
z_local, _ = self._split_tensor(z)
61+
out = self.vae_backend.decode(z_local, **kwargs)
62+
x_local = out[0] if isinstance(out, (tuple, list)) else out
63+
return self._gather_tensor(x_local)
64+
65+
# ------------------------------------------------------------------
66+
# Module parallelisation
67+
# ------------------------------------------------------------------
68+
69+
def _parallelize_modules(self) -> None:
70+
self._replace_conv3d(self.vae_backend.decoder)
71+
self._replace_attention(self.vae_backend.decoder)
72+
self._replace_resample_conv2d(self.vae_backend.decoder)
73+
self._replace_conv3d(self.vae_backend.encoder)
74+
self._replace_attention(self.vae_backend.encoder)
75+
self._replace_resample_conv2d_stride2(self.vae_backend.encoder)
5276

5377
def _replace_conv3d(self, model: nn.Module) -> None:
5478
"""Replace WanCausalConv3d (kernel > 1) with WanCausalConvHalo."""
@@ -63,15 +87,15 @@ def _replace_conv3d(self, model: nn.Module) -> None:
6387
name,
6488
WanCausalConvHalo(
6589
module,
66-
self.chunk_dims["conv3d"],
67-
self.adj_groups,
90+
self.spec.conv3d_dim,
91+
self._adj_groups,
6892
self.rank,
6993
self.world_size,
7094
),
7195
)
7296

7397
def _replace_attention(self, model: nn.Module) -> None:
74-
"""Replace WanAttentionBlock with GatherAttention."""
98+
"""Replace WanAttentionBlock with parallel gather-attention."""
7599
targets = [
76100
(name, module)
77101
for name, module in model.named_modules()
@@ -83,19 +107,14 @@ def _replace_attention(self, model: nn.Module) -> None:
83107
name,
84108
ParallelVaeAttentionBlock(
85109
module,
86-
self.chunk_dims["attn"],
110+
self.spec.attn_dim,
87111
self.rank,
88112
self.world_size,
89113
),
90114
)
91115

92116
def _replace_resample_conv2d(self, model: nn.Module) -> None:
93-
"""Replace stride-1 Conv2d inside WanResample upsample paths.
94-
95-
WanResample.resample for upsample modes is:
96-
Sequential(WanUpsample, Conv2d(dim, out, 3, padding=1))
97-
The Conv2d is a standard 2D conv on per-frame data (B*T, C, H, W).
98-
"""
117+
"""Replace stride-1 Conv2d inside WanResample upsample paths."""
99118
targets = [
100119
(name, module)
101120
for name, module in model.named_modules()
@@ -110,21 +129,15 @@ def _replace_resample_conv2d(self, model: nn.Module) -> None:
110129
name,
111130
HaloExchangeConv(
112131
module,
113-
self.chunk_dims["conv2d"],
114-
self.adj_groups,
132+
self.spec.conv2d_dim,
133+
self._adj_groups,
115134
self.rank,
116135
self.world_size,
117136
),
118137
)
119138

120139
def _replace_resample_conv2d_stride2(self, model: nn.Module) -> None:
121-
"""Replace stride-2 Conv2d inside WanResample downsample paths.
122-
123-
WanResample.resample for downsample modes is:
124-
Sequential(ZeroPad2d((0,1,0,1)), Conv2d(dim, dim, 3, stride=(2,2)))
125-
We replace the entire Sequential with HaloExchangeConv2dStride2, which
126-
absorbs the ZeroPad2d logic.
127-
"""
140+
"""Replace stride-2 Conv2d inside WanResample downsample paths."""
128141
targets = [
129142
(name, module)
130143
for name, module in model.named_modules()
@@ -142,8 +155,8 @@ def _replace_resample_conv2d_stride2(self, model: nn.Module) -> None:
142155
name,
143156
HaloExchangeConv2dStride2(
144157
conv_module,
145-
self.chunk_dims["conv2d"],
146-
self.adj_groups,
158+
self.spec.conv2d_dim,
159+
self._adj_groups,
147160
self.rank,
148161
self.world_size,
149162
pad_before_conv=pad_module.padding,

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Optional, Type
2+
from typing import Optional
33

44
import diffusers
55
import torch
@@ -17,7 +17,6 @@
1717
from tensorrt_llm._utils import nvtx_range
1818
from tensorrt_llm.logger import logger
1919

20-
from .parallel_vae import WanParallelVAEAdapter
2120
from .transformer_wan import WanTransformer3DModel
2221

2322
# Supported Wan T2V models:
@@ -129,10 +128,6 @@ def common_warmup_shapes(self) -> list:
129128
"""Return list of common warmup shapes for the pipeline."""
130129
return [(480, 832, 33), (480, 832, 81), (720, 1280, 81)]
131130

132-
@property
133-
def vae_adapter_class(self) -> Type[WanParallelVAEAdapter]:
134-
return WanParallelVAEAdapter
135-
136131
def _init_transformer(self) -> None:
137132
logger.info("Creating WAN transformer with quantization support...")
138133
self.transformer = WanTransformer3DModel(model_config=self.model_config)

tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33
import time
4-
from typing import Optional, Tuple, Type, Union
4+
from typing import Optional, Tuple, Union
55

66
import diffusers
77
import PIL.Image
@@ -19,8 +19,6 @@
1919
from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
2020
from tensorrt_llm.logger import logger
2121

22-
from .parallel_vae import WanParallelVAEAdapter
23-
2422
# Supported Wan I2V 14B models:
2523
# - Wan2.1-I2V-14B-480P: Single-stage image-to-video
2624
# - Wan2.1-I2V-14B-720P: Single-stage image-to-video
@@ -150,10 +148,6 @@ def common_warmup_shapes(self) -> list:
150148
"""Return list of common warmup shapes for the pipeline."""
151149
return [(480, 832, 33), (480, 832, 81), (720, 1280, 81)]
152150

153-
@property
154-
def vae_adapter_class(self) -> Type[WanParallelVAEAdapter]:
155-
return WanParallelVAEAdapter
156-
157151
def _init_transformer(self) -> None:
158152
logger.info("Creating WAN I2V transformer with quantization support...")
159153
self.transformer = WanTransformer3DModel(model_config=self.model_config)
(filename missing from capture — presumably tensorrt_llm/_torch/visual_gen/modules/vae/__init__.py, based on the relative imports below; verify against the commit)

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from .attention import ParallelVaeAttentionBlock
22
from .conv import HaloExchangeConv, HaloExchangeConv2dStride2
33
from .norm import GroupNormParallel
4-
from .parallel_vae_interface import BaseParallelVAEAdapter
4+
from .parallel_vae_interface import ParallelVAEBase, ParallelVAEFactory, SplitSpec
55

66
__all__ = [
77
"ParallelVaeAttentionBlock",
88
"HaloExchangeConv",
99
"HaloExchangeConv2dStride2",
1010
"GroupNormParallel",
11-
"BaseParallelVAEAdapter",
11+
"ParallelVAEBase",
12+
"ParallelVAEFactory",
13+
"SplitSpec",
1214
]

0 commit comments

Comments (0)