11from typing import Literal
22
3+ import torch
34import torch .nn as nn
45from diffusers .models .autoencoders .autoencoder_kl_wan import WanAttentionBlock , WanCausalConv3d
56
67from tensorrt_llm ._torch .visual_gen .modules .vae import (
7- BaseParallelVAEAdapter ,
88 HaloExchangeConv ,
99 HaloExchangeConv2dStride2 ,
1010 ParallelVaeAttentionBlock ,
1111)
12+ from tensorrt_llm ._torch .visual_gen .modules .vae .parallel_vae_interface import (
13+ ParallelVAEBase ,
14+ SplitSpec ,
15+ )
1216from tensorrt_llm ._torch .visual_gen .utils import as_tuple
1317
1418
@@ -26,29 +30,49 @@ def forward(self, x, cache_x=None, *args, **kwargs):
2630 return self ._strip_halo (result )
2731
2832
29- class WanParallelVAEAdapter ( BaseParallelVAEAdapter ):
30- """Parallel VAE adapter for ``AutoencoderKLWan``."""
33+ class ParallelVAE_Wan ( ParallelVAEBase ):
34+ """Parallel VAE wrapper for ``AutoencoderKLWan``."""
3135
32- def _get_chunk_dims (self , split_dim : Literal ["height" , "width" ]) -> dict :
36+ @staticmethod
37+ def make_spec (split_dim : Literal ["height" , "width" ]) -> SplitSpec :
3338 # WAN tensor shapes:
34- # 5D latent/video : (B, C, T, H, W) → H=dim3, W=dim4
35- # 4D per-frame : (B*T, C, H, W) → H=dim2, W=dim3
36- # 5D attention in : (B, C, T, H, W) → H=dim3, W=dim4
39+ # 5D latent/video : (B, C, T, H, W) -> H=dim3, W=dim4
40+ # 4D per-frame : (B*T, C, H, W) -> H=dim2, W=dim3
41+ # 5D attention in : (B, C, T, H, W) -> H=dim3, W=dim4
3742 if split_dim == "height" :
38- return { "input" : 3 , "conv3d" : 3 , "conv2d" : 2 , "attn" : 3 }
39- elif split_dim == "width" :
40- return { "input" : 4 , "conv3d" : 4 , "conv2d" : 3 , "attn" : 4 }
43+ return SplitSpec ( split_dim , input_dim = 3 , conv3d_dim = 3 , conv2d_dim = 2 , attn_dim = 3 )
44+ if split_dim == "width" :
45+ return SplitSpec ( split_dim , input_dim = 4 , conv3d_dim = 4 , conv2d_dim = 3 , attn_dim = 4 )
4146 raise ValueError (f"Invalid split_dim: { split_dim } " )
4247
43- def _parallelize_decoder (self ) -> None :
44- self ._replace_conv3d (self .vae .decoder )
45- self ._replace_attention (self .vae .decoder )
46- self ._replace_resample_conv2d (self .vae .decoder )
47-
48- def _parallelize_encoder (self ) -> None :
49- self ._replace_conv3d (self .vae .encoder )
50- self ._replace_attention (self .vae .encoder )
51- self ._replace_resample_conv2d_stride2 (self .vae .encoder )
48+ # ------------------------------------------------------------------
49+ # encode / decode
50+ # ------------------------------------------------------------------
51+
52+ def _encode_impl (self , x : torch .Tensor , ** kwargs ) -> torch .Tensor :
53+ x_local , _ = self ._split_tensor (x )
54+ z_local = self .vae_backend .encode (x_local , ** kwargs )
55+ if isinstance (z_local , (tuple , list )):
56+ z_local = z_local [0 ]
57+ return self ._gather_tensor (z_local )
58+
59+ def _decode_impl (self , z : torch .Tensor , ** kwargs ) -> torch .Tensor :
60+ z_local , _ = self ._split_tensor (z )
61+ out = self .vae_backend .decode (z_local , ** kwargs )
62+ x_local = out [0 ] if isinstance (out , (tuple , list )) else out
63+ return self ._gather_tensor (x_local )
64+
65+ # ------------------------------------------------------------------
66+ # Module parallelisation
67+ # ------------------------------------------------------------------
68+
69+ def _parallelize_modules (self ) -> None :
70+ self ._replace_conv3d (self .vae_backend .decoder )
71+ self ._replace_attention (self .vae_backend .decoder )
72+ self ._replace_resample_conv2d (self .vae_backend .decoder )
73+ self ._replace_conv3d (self .vae_backend .encoder )
74+ self ._replace_attention (self .vae_backend .encoder )
75+ self ._replace_resample_conv2d_stride2 (self .vae_backend .encoder )
5276
5377 def _replace_conv3d (self , model : nn .Module ) -> None :
5478 """Replace WanCausalConv3d (kernel > 1) with WanCausalConvHalo."""
@@ -63,15 +87,15 @@ def _replace_conv3d(self, model: nn.Module) -> None:
6387 name ,
6488 WanCausalConvHalo (
6589 module ,
66- self .chunk_dims [ "conv3d" ] ,
67- self .adj_groups ,
90+ self .spec . conv3d_dim ,
91+ self ._adj_groups ,
6892 self .rank ,
6993 self .world_size ,
7094 ),
7195 )
7296
7397 def _replace_attention (self , model : nn .Module ) -> None :
74- """Replace WanAttentionBlock with GatherAttention ."""
98+ """Replace WanAttentionBlock with parallel gather-attention ."""
7599 targets = [
76100 (name , module )
77101 for name , module in model .named_modules ()
@@ -83,19 +107,14 @@ def _replace_attention(self, model: nn.Module) -> None:
83107 name ,
84108 ParallelVaeAttentionBlock (
85109 module ,
86- self .chunk_dims [ "attn" ] ,
110+ self .spec . attn_dim ,
87111 self .rank ,
88112 self .world_size ,
89113 ),
90114 )
91115
92116 def _replace_resample_conv2d (self , model : nn .Module ) -> None :
93- """Replace stride-1 Conv2d inside WanResample upsample paths.
94-
95- WanResample.resample for upsample modes is:
96- Sequential(WanUpsample, Conv2d(dim, out, 3, padding=1))
97- The Conv2d is a standard 2D conv on per-frame data (B*T, C, H, W).
98- """
117+ """Replace stride-1 Conv2d inside WanResample upsample paths."""
99118 targets = [
100119 (name , module )
101120 for name , module in model .named_modules ()
@@ -110,21 +129,15 @@ def _replace_resample_conv2d(self, model: nn.Module) -> None:
110129 name ,
111130 HaloExchangeConv (
112131 module ,
113- self .chunk_dims [ "conv2d" ] ,
114- self .adj_groups ,
132+ self .spec . conv2d_dim ,
133+ self ._adj_groups ,
115134 self .rank ,
116135 self .world_size ,
117136 ),
118137 )
119138
120139 def _replace_resample_conv2d_stride2 (self , model : nn .Module ) -> None :
121- """Replace stride-2 Conv2d inside WanResample downsample paths.
122-
123- WanResample.resample for downsample modes is:
124- Sequential(ZeroPad2d((0,1,0,1)), Conv2d(dim, dim, 3, stride=(2,2)))
125- We replace the entire Sequential with HaloExchangeConv2dStride2, which
126- absorbs the ZeroPad2d logic.
127- """
140+ """Replace stride-2 Conv2d inside WanResample downsample paths."""
128141 targets = [
129142 (name , module )
130143 for name , module in model .named_modules ()
@@ -142,8 +155,8 @@ def _replace_resample_conv2d_stride2(self, model: nn.Module) -> None:
142155 name ,
143156 HaloExchangeConv2dStride2 (
144157 conv_module ,
145- self .chunk_dims [ "conv2d" ] ,
146- self .adj_groups ,
158+ self .spec . conv2d_dim ,
159+ self ._adj_groups ,
147160 self .rank ,
148161 self .world_size ,
149162 pad_before_conv = pad_module .padding ,