Skip to content

Commit e96a028

Browse files
committed
Support VACE with TeaCache
1 parent 5736669 commit e96a028

File tree

1 file changed

+126
-55
lines changed

1 file changed

+126
-55
lines changed

nodes/model_optimization_nodes.py

Lines changed: 126 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -990,7 +990,58 @@ def relative_l1_distance(last_tensor, current_tensor):
990990
relative_l1_distance = l1_distance / norm
991991
return relative_l1_distance.to(torch.float32)
992992

993-
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
993+
@torch.compiler.disable()
994+
def tea_cache(self, x, e0, e, transformer_options):
995+
#teacache for cond and uncond separately
996+
rel_l1_thresh = transformer_options["rel_l1_thresh"]
997+
998+
is_cond = True if transformer_options["cond_or_uncond"] == [0] else False
999+
1000+
should_calc = True
1001+
suffix = "cond" if is_cond else "uncond"
1002+
1003+
# Init cache dict if not exists
1004+
if not hasattr(self, 'teacache_state'):
1005+
self.teacache_state = {
1006+
'cond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
1007+
'teacache_skipped_steps': 0, 'previous_residual': None},
1008+
'uncond': {'accumulated_rel_l1_distance': 0, 'prev_input': None,
1009+
'teacache_skipped_steps': 0, 'previous_residual': None}
1010+
}
1011+
logging.info("\nTeaCache: Initialized")
1012+
1013+
cache = self.teacache_state[suffix]
1014+
1015+
if cache['prev_input'] is not None:
1016+
if transformer_options["coefficients"] == []:
1017+
temb_relative_l1 = relative_l1_distance(cache['prev_input'], e0)
1018+
curr_acc_dist = cache['accumulated_rel_l1_distance'] + temb_relative_l1
1019+
else:
1020+
rescale_func = np.poly1d(transformer_options["coefficients"])
1021+
curr_acc_dist = cache['accumulated_rel_l1_distance'] + rescale_func(((e-cache['prev_input']).abs().mean() / cache['prev_input'].abs().mean()).cpu().item())
1022+
try:
1023+
if curr_acc_dist < rel_l1_thresh:
1024+
should_calc = False
1025+
cache['accumulated_rel_l1_distance'] = curr_acc_dist
1026+
else:
1027+
should_calc = True
1028+
cache['accumulated_rel_l1_distance'] = 0
1029+
except:
1030+
should_calc = True
1031+
cache['accumulated_rel_l1_distance'] = 0
1032+
1033+
if transformer_options["coefficients"] == []:
1034+
cache['prev_input'] = e0.clone().detach()
1035+
else:
1036+
cache['prev_input'] = e.clone().detach()
1037+
1038+
if not should_calc:
1039+
x += cache['previous_residual'].to(x.device)
1040+
cache['teacache_skipped_steps'] += 1
1041+
#print(f"TeaCache: Skipping {suffix} step")
1042+
return should_calc, cache
1043+
1044+
def teacache_wanvideo_vace_forward_orig(self, x, t, context, vace_context, vace_strength, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
    """VACE-aware WanVideo ``forward_orig`` replacement with TeaCache skipping.

    Mirrors ``teacache_wanvideo_forward_orig`` but additionally runs the VACE
    control blocks: at each transformer layer that has a VACE counterpart,
    the control skip-connection is added to ``x`` scaled by ``vace_strength``.

    Args:
        x: Input latent video tensor.
        t: Diffusion timestep(s).
        context: Text conditioning tokens.
        vace_context: VACE control latents.
        vace_strength: Per-control-input scale factors for the VACE skips.
        clip_fea: Optional CLIP image features prepended to the context.
        freqs: Rotary position frequencies passed to each block.
        transformer_options: ComfyUI transformer options; must be non-empty
            (requires ComfyUI nightly from Mar 14, 2025 or later).

    Returns:
        The denoised, unpatchified output tensor.
    """
    # embeddings
    x = self.patch_embedding(x.float()).to(x.dtype)
    grid_sizes = x.shape[2:]
    x = x.flatten(2).transpose(1, 2)

    # time embeddings
    e = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
    e0 = self.time_projection(e).unflatten(1, (6, self.dim))

    # context
    context = self.text_embedding(context)

    context_img_len = None
    if clip_fea is not None:
        if self.img_emb is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)
        context_img_len = clip_fea.shape[-2]

    # VACE context: fold control inputs into the batch dim, patchify, then
    # split back into one token sequence per control input.
    orig_shape = list(vace_context.shape)
    vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
    c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
    c = c.flatten(2).transpose(1, 2)
    c = list(c.split(orig_shape[0], dim=0))

    if not transformer_options:
        raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")

    teacache_enabled = transformer_options.get("teacache_enabled", False)
    if not teacache_enabled:
        should_calc = True
    else:
        # May skip this step entirely by adding the cached residual to x.
        should_calc, cache = tea_cache(self, x, e0, e, transformer_options)

    if should_calc:
        # Keep the pre-block tokens: the VACE blocks condition on them, and
        # the TeaCache residual is (output - input).
        original_x = x.clone().detach()
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

            # Apply the VACE block mapped to this layer (if any) and add its
            # skip connection, scaled per control input.
            ii = self.vace_layers_mapping.get(i, None)
            if ii is not None:
                for iii in range(len(c)):
                    c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=original_x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
                    x += c_skip * vace_strength[iii]
                del c_skip

        if teacache_enabled:
            cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])

    # head
    x = self.head(x, e)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return x
1111+
1112+
def teacache_wanvideo_forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs):
1113+
# embeddings
1114+
x = self.patch_embedding(x.float()).to(x.dtype)
1115+
grid_sizes = x.shape[2:]
1116+
x = x.flatten(2).transpose(1, 2)
1117+
1118+
# time embeddings
1119+
e = self.time_embedding(
1120+
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
1121+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
1122+
1123+
# context
1124+
context = self.text_embedding(context)
1125+
1126+
context_img_len = None
1127+
if clip_fea is not None:
1128+
if self.img_emb is not None:
1129+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
1130+
context = torch.concat([context_clip, context], dim=1)
1131+
context_img_len = clip_fea.shape[-2]
10541132

1055-
if not should_calc:
1056-
x += cache['previous_residual'].to(x.device)
1057-
cache['teacache_skipped_steps'] += 1
1058-
#print(f"TeaCache: Skipping {suffix} step")
1059-
return should_calc, cache
1060-
1061-
if not transformer_options:
1062-
raise RuntimeError("Can't access transformer_options, this requires ComfyUI nightly version from Mar 14, 2025 or later")
10631133

10641134
teacache_enabled = transformer_options.get("teacache_enabled", False)
10651135
if not teacache_enabled:
10661136
should_calc = True
10671137
else:
1068-
should_calc, cache = tea_cache(x, e0, e, kwargs)
1138+
should_calc, cache = tea_cache(self, x, e0, e, transformer_options)
10691139

10701140
if should_calc:
10711141
original_x = x.clone().detach()
@@ -1075,12 +1145,12 @@ def tea_cache(x, e0, e, kwargs):
10751145
if ("double_block", i) in blocks_replace:
10761146
def block_wrap(args):
10771147
out = {}
1078-
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
1148+
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
10791149
return out
10801150
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options})
10811151
x = out["img"]
10821152
else:
1083-
x = block(x, e=e0, freqs=freqs, context=context)
1153+
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
10841154

10851155
if teacache_enabled:
10861156
cache['previous_residual'] = (x - original_x).to(transformer_options["teacache_device"])
@@ -1206,9 +1276,10 @@ def unet_wrapper_function(model_function, kwargs):
12061276
if start_percent <= current_percent <= end_percent:
12071277
c["transformer_options"]["teacache_enabled"] = True
12081278

1279+
forward_function = teacache_wanvideo_vace_forward_orig if hasattr(diffusion_model, "vace_layers") else teacache_wanvideo_forward_orig
12091280
context = patch.multiple(
12101281
diffusion_model,
1211-
forward_orig=teacache_wanvideo_forward_orig.__get__(diffusion_model, diffusion_model.__class__)
1282+
forward_orig=forward_function.__get__(diffusion_model, diffusion_model.__class__)
12121283
)
12131284

12141285
with context:

0 commit comments

Comments
 (0)