@@ -990,7 +990,58 @@ def relative_l1_distance(last_tensor, current_tensor):
990990 relative_l1_distance = l1_distance / norm
991991 return relative_l1_distance .to (torch .float32 )
992992
993- def teacache_wanvideo_forward_orig (self , x , t , context , clip_fea = None , freqs = None , transformer_options = {}, ** kwargs ):
993+ @torch .compiler .disable ()
994+ def tea_cache (self , x , e0 , e , transformer_options ):
995+ #teacache for cond and uncond separately
996+ rel_l1_thresh = transformer_options ["rel_l1_thresh" ]
997+
998+ is_cond = True if transformer_options ["cond_or_uncond" ] == [0 ] else False
999+
1000+ should_calc = True
1001+ suffix = "cond" if is_cond else "uncond"
1002+
1003+ # Init cache dict if not exists
1004+ if not hasattr (self , 'teacache_state' ):
1005+ self .teacache_state = {
1006+ 'cond' : {'accumulated_rel_l1_distance' : 0 , 'prev_input' : None ,
1007+ 'teacache_skipped_steps' : 0 , 'previous_residual' : None },
1008+ 'uncond' : {'accumulated_rel_l1_distance' : 0 , 'prev_input' : None ,
1009+ 'teacache_skipped_steps' : 0 , 'previous_residual' : None }
1010+ }
1011+ logging .info ("\n TeaCache: Initialized" )
1012+
1013+ cache = self .teacache_state [suffix ]
1014+
1015+ if cache ['prev_input' ] is not None :
1016+ if transformer_options ["coefficients" ] == []:
1017+ temb_relative_l1 = relative_l1_distance (cache ['prev_input' ], e0 )
1018+ curr_acc_dist = cache ['accumulated_rel_l1_distance' ] + temb_relative_l1
1019+ else :
1020+ rescale_func = np .poly1d (transformer_options ["coefficients" ])
1021+ curr_acc_dist = cache ['accumulated_rel_l1_distance' ] + rescale_func (((e - cache ['prev_input' ]).abs ().mean () / cache ['prev_input' ].abs ().mean ()).cpu ().item ())
1022+ try :
1023+ if curr_acc_dist < rel_l1_thresh :
1024+ should_calc = False
1025+ cache ['accumulated_rel_l1_distance' ] = curr_acc_dist
1026+ else :
1027+ should_calc = True
1028+ cache ['accumulated_rel_l1_distance' ] = 0
1029+ except :
1030+ should_calc = True
1031+ cache ['accumulated_rel_l1_distance' ] = 0
1032+
1033+ if transformer_options ["coefficients" ] == []:
1034+ cache ['prev_input' ] = e0 .clone ().detach ()
1035+ else :
1036+ cache ['prev_input' ] = e .clone ().detach ()
1037+
1038+ if not should_calc :
1039+ x += cache ['previous_residual' ].to (x .device )
1040+ cache ['teacache_skipped_steps' ] += 1
1041+ #print(f"TeaCache: Skipping {suffix} step")
1042+ return should_calc , cache
1043+
1044+ def teacache_wanvideo_vace_forward_orig (self , x , t , context , vace_context , vace_strength , clip_fea = None , freqs = None , transformer_options = {}, ** kwargs ):
9941045 # embeddings
9951046 x = self .patch_embedding (x .float ()).to (x .dtype )
9961047 grid_sizes = x .shape [2 :]
@@ -1003,69 +1054,88 @@ def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=Non
10031054
10041055 # context
10051056 context = self .text_embedding (context )
1006- if clip_fea is not None and self .img_emb is not None :
1007- context_clip = self .img_emb (clip_fea ) # bs x 257 x dim
1008- context = torch .concat ([context_clip , context ], dim = 1 )
10091057
1010- @ torch . compiler . disable ()
1011- def tea_cache ( x , e0 , e , kwargs ) :
1012- #teacache for cond and uncond separately
1013- rel_l1_thresh = transformer_options [ "rel_l1_thresh" ]
1014-
1015- is_cond = True if transformer_options [ "cond_or_uncond" ] == [ 0 ] else False
1058+ context_img_len = None
1059+ if clip_fea is not None :
1060+ if self . img_emb is not None :
1061+ context_clip = self . img_emb ( clip_fea ) # bs x 257 x dim
1062+ context = torch . concat ([ context_clip , context ], dim = 1 )
1063+ context_img_len = clip_fea . shape [ - 2 ]
10161064
1017- should_calc = True
1018- suffix = "cond" if is_cond else "uncond"
1019-
1020- # Init cache dict if not exists
1021- if not hasattr (self , 'teacache_state' ):
1022- self .teacache_state = {
1023- 'cond' : {'accumulated_rel_l1_distance' : 0 , 'prev_input' : None ,
1024- 'teacache_skipped_steps' : 0 , 'previous_residual' : None },
1025- 'uncond' : {'accumulated_rel_l1_distance' : 0 , 'prev_input' : None ,
1026- 'teacache_skipped_steps' : 0 , 'previous_residual' : None }
1027- }
1028- logging .info ("\n TeaCache: Initialized" )
1065+ orig_shape = list (vace_context .shape )
1066+ vace_context = vace_context .movedim (0 , 1 ).reshape ([- 1 ] + orig_shape [2 :])
1067+ c = self .vace_patch_embedding (vace_context .float ()).to (vace_context .dtype )
1068+ c = c .flatten (2 ).transpose (1 , 2 )
1069+ c = list (c .split (orig_shape [0 ], dim = 0 ))
10291070
1030- cache = self .teacache_state [suffix ]
1071+ if not transformer_options :
1072+ raise RuntimeError ("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later" )
10311073
1032- if cache ['prev_input' ] is not None :
1033- if transformer_options ["coefficients" ] == []:
1034- temb_relative_l1 = relative_l1_distance (cache ['prev_input' ], e0 )
1035- curr_acc_dist = cache ['accumulated_rel_l1_distance' ] + temb_relative_l1
1074+ teacache_enabled = transformer_options .get ("teacache_enabled" , False )
1075+ if not teacache_enabled :
1076+ should_calc = True
1077+ else :
1078+ should_calc , cache = tea_cache (self , x , e0 , e , transformer_options )
1079+
1080+ if should_calc :
1081+ original_x = x .clone ().detach ()
1082+ patches_replace = transformer_options .get ("patches_replace" , {})
1083+ blocks_replace = patches_replace .get ("dit" , {})
1084+ for i , block in enumerate (self .blocks ):
1085+ if ("double_block" , i ) in blocks_replace :
1086+ def block_wrap (args ):
1087+ out = {}
1088+ out ["img" ] = block (args ["img" ], context = args ["txt" ], e = args ["vec" ], freqs = args ["pe" ], context_img_len = context_img_len )
1089+ return out
1090+ out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "vec" : e0 , "pe" : freqs }, {"original_block" : block_wrap , "transformer_options" : transformer_options })
1091+ x = out ["img" ]
10361092 else :
1037- rescale_func = np .poly1d (transformer_options ["coefficients" ])
1038- curr_acc_dist = cache ['accumulated_rel_l1_distance' ] + rescale_func (((e - cache ['prev_input' ]).abs ().mean () / cache ['prev_input' ].abs ().mean ()).cpu ().item ())
1039- try :
1040- if curr_acc_dist < rel_l1_thresh :
1041- should_calc = False
1042- cache ['accumulated_rel_l1_distance' ] = curr_acc_dist
1043- else :
1044- should_calc = True
1045- cache ['accumulated_rel_l1_distance' ] = 0
1046- except :
1047- should_calc = True
1048- cache ['accumulated_rel_l1_distance' ] = 0
1093+ x = block (x , e = e0 , freqs = freqs , context = context , context_img_len = context_img_len )
10491094
1050- if transformer_options ["coefficients" ] == []:
1051- cache ['prev_input' ] = e0 .clone ().detach ()
1052- else :
1053- cache ['prev_input' ] = e .clone ().detach ()
1095+ ii = self .vace_layers_mapping .get (i , None )
1096+ if ii is not None :
1097+ for iii in range (len (c )):
1098+ c_skip , c [iii ] = self .vace_blocks [ii ](c [iii ], x = original_x , e = e0 , freqs = freqs , context = context , context_img_len = context_img_len )
1099+ x += c_skip * vace_strength [iii ]
1100+ del c_skip
1101+
1102+ if teacache_enabled :
1103+ cache ['previous_residual' ] = (x - original_x ).to (transformer_options ["teacache_device" ])
1104+
1105+ # head
1106+ x = self .head (x , e )
1107+
1108+ # unpatchify
1109+ x = self .unpatchify (x , grid_sizes )
1110+ return x
1111+
1112+ def teacache_wanvideo_forward_orig (self , x , t , context , clip_fea = None , freqs = None , transformer_options = {}, ** kwargs ):
1113+ # embeddings
1114+ x = self .patch_embedding (x .float ()).to (x .dtype )
1115+ grid_sizes = x .shape [2 :]
1116+ x = x .flatten (2 ).transpose (1 , 2 )
1117+
1118+ # time embeddings
1119+ e = self .time_embedding (
1120+ sinusoidal_embedding_1d (self .freq_dim , t ).to (dtype = x [0 ].dtype ))
1121+ e0 = self .time_projection (e ).unflatten (1 , (6 , self .dim ))
1122+
1123+ # context
1124+ context = self .text_embedding (context )
1125+
1126+ context_img_len = None
1127+ if clip_fea is not None :
1128+ if self .img_emb is not None :
1129+ context_clip = self .img_emb (clip_fea ) # bs x 257 x dim
1130+ context = torch .concat ([context_clip , context ], dim = 1 )
1131+ context_img_len = clip_fea .shape [- 2 ]
10541132
1055- if not should_calc :
1056- x += cache ['previous_residual' ].to (x .device )
1057- cache ['teacache_skipped_steps' ] += 1
1058- #print(f"TeaCache: Skipping {suffix} step")
1059- return should_calc , cache
1060-
1061- if not transformer_options :
1062- raise RuntimeError ("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later" )
10631133
10641134 teacache_enabled = transformer_options .get ("teacache_enabled" , False )
10651135 if not teacache_enabled :
10661136 should_calc = True
10671137 else :
1068- should_calc , cache = tea_cache (x , e0 , e , kwargs )
1138+ should_calc , cache = tea_cache (self , x , e0 , e , transformer_options )
10691139
10701140 if should_calc :
10711141 original_x = x .clone ().detach ()
@@ -1075,12 +1145,12 @@ def tea_cache(x, e0, e, kwargs):
10751145 if ("double_block" , i ) in blocks_replace :
10761146 def block_wrap (args ):
10771147 out = {}
1078- out ["img" ] = block (args ["img" ], context = args ["txt" ], e = args ["vec" ], freqs = args ["pe" ])
1148+ out ["img" ] = block (args ["img" ], context = args ["txt" ], e = args ["vec" ], freqs = args ["pe" ], context_img_len = context_img_len )
10791149 return out
10801150 out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "vec" : e0 , "pe" : freqs }, {"original_block" : block_wrap , "transformer_options" : transformer_options })
10811151 x = out ["img" ]
10821152 else :
1083- x = block (x , e = e0 , freqs = freqs , context = context )
1153+ x = block (x , e = e0 , freqs = freqs , context = context , context_img_len = context_img_len )
10841154
10851155 if teacache_enabled :
10861156 cache ['previous_residual' ] = (x - original_x ).to (transformer_options ["teacache_device" ])
@@ -1206,9 +1276,10 @@ def unet_wrapper_function(model_function, kwargs):
12061276 if start_percent <= current_percent <= end_percent :
12071277 c ["transformer_options" ]["teacache_enabled" ] = True
12081278
1279+ forward_function = teacache_wanvideo_vace_forward_orig if hasattr (diffusion_model , "vace_layers" ) else teacache_wanvideo_forward_orig
12091280 context = patch .multiple (
12101281 diffusion_model ,
1211- forward_orig = teacache_wanvideo_forward_orig .__get__ (diffusion_model , diffusion_model .__class__ )
1282+ forward_orig = forward_function .__get__ (diffusion_model , diffusion_model .__class__ )
12121283 )
12131284
12141285 with context :
0 commit comments